#!/usr/bin/env python

import copy
try:
    from hashlib import md5        # Python 2.5+
except ImportError:
    from md5 import new as md5     # older Pythons
import os
import sys
import cStringIO as StringIO
import re
import struct
import zipfile

PATHS = {"libdir": "/usr/lib/gcj",
         "gcj":    "/usr/bin/gcj",
         "dbtool": "/usr/bin/gcj-dbtool"}

GCJFLAGS = ["-fPIC", "-findirect-dispatch", "-fjni"]
LDFLAGS = ["-Wl,-Bsymbolic"]
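# -fPIC emits position-independent code, -findirect-dispatch selects
# gcj's binary compatibility ABI, -fjni compiles native methods against
# JNI rather than CNI, and -Wl,-Bsymbolic binds the library's internal
# symbol references at link time.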

MAX_CLASSES_PER_JAR = 1024
MAX_BYTES_PER_JAR = 1048576
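# Oversized collections are rewritten as several jars, capped at this
# many classes or bytes apiece, presumably to keep each gcj invocation
# down to a manageable size.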

ZIPMAGIC, CLASSMAGIC = "PK\x03\x04", "\xca\xfe\xba\xbe"
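# "PK\x03\x04" is the ZIP local file header signature; "\xca\xfe\xba\xbe"
# is the magic number that opens every Java classfile.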

class Error(Exception):
    pass

def aot_compile(basedir, libdir, exclusions = ()):
    """Search basedir for classes and jarfiles, then generate solibs
    and mappings databases for them all in libdir."""
    dstdir = os.path.join(basedir, libdir.lstrip(os.sep))
    if not os.path.isdir(dstdir):
        os.makedirs(dstdir)
    jobs = weed_jobs(find_jobs(basedir, exclusions))
    set_basenames(jobs)
    for job in jobs:
        job.compile(dstdir, libdir)
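
# Illustrative use (paths hypothetical):
#   aot_compile("/var/tmp/buildroot", "/usr/lib/gcj/mypackage")
# compiles every job found under the buildroot into
# <buildroot>/usr/lib/gcj/mypackage.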

def find_jobs(dir, exclusions = ()):
    """Scan a directory and find things to compile: jarfiles (zips,
    wars, ears, rars, etc: we go by magic rather than file extension)
    and directories of classes."""
    def visit((classes, zips), dir, items):
        for item in items:
            path = os.path.join(dir, item)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            magic = open(path, "rb").read(4)
            if magic == ZIPMAGIC:
                zips.append(path)
            elif magic == CLASSMAGIC:
                classes.append(path)
    # Jarfiles are jobs in their own right, so visit() appends them
    # directly onto the list of job paths.
    classes, paths = [], []
    os.path.walk(dir, visit, (classes, paths))
    # Convert the list of classes into a list of directories
    while classes:
        # XXX this requires the class to be correctly located in its hierarchy.
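        # e.g. (hypothetical) "/build/classes/com/example/Foo.class" with
        # classname "com/example/Foo" trims to "/build/classes".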
        path = classes[0][:-len(os.sep + classname(classes[0]) + ".class")]
        paths.append(path)
        classes = [cls for cls in classes if not cls.startswith(path)]
    # Handle exclusions.  We're really strict about them because the
    # option is temporary and dead options left in specfiles will
    # hinder its removal.
    for path in exclusions:
        if path in paths:
            paths.remove(path)
        else:
            raise Error("%s: path does not exist or is not a job" % path)
    # Build the list of jobs
    jobs = []
    paths.sort()
    for path in paths:
        if os.path.isfile(path):
            job = JarJob(path)
        else:
            job = DirJob(path)
        if job.classes:
            jobs.append(job)
    return jobs

class Job:
    """A collection of classes that will be compiled as a unit."""
    
    def __init__(self, path):
        self.path, self.classes = path, {}

    def addClass(self, bytes):
        self.classes[md5(bytes).digest()] = bytes
    
    # Renamed from isSubsetOf: as Archit Shah observed, the original
    # implementation tested the superset relation ([a, b].isSubsetOf([a])
    # returned True) while its docstring described the subset one.

    def isSupersetOf(self, other):
        """Returns True if identical copies of all classes in the
        other collection exist in this one."""
        for item in other.classes:
            if item not in self.classes:
                return False
        return True
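
    # For example (hypothetical): if a.classes holds hashes {x, y} and
    # b.classes holds only {x}, a.isSupersetOf(b) is True while
    # b.isSupersetOf(a) is False.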

    def writeSources(self, dir):
        """Generate jarfiles that can be native compiled by gcj.  In
        the majority of cases this is not necessary -- the collection
        will have come from a jarfile which will be equivalent to the
        one we generate -- but this only happens _if_ the collection
        was a jarfile and _if_ the jarfile isn't too big and _if_ the
        jarfile has the correct extension and _if_ all classes are
        correctly named and _if_ the jarfile has no embedded jarfiles.
        Fitting a special case around all these conditions is tricky
        to say the least."""
        names = {}
        for hash, bytes in self.classes.items():
            name = classname(bytes)
            names.setdefault(name, []).append(hash)
        names = names.items()
        # we sort by name in a simplistic attempt to keep related
        # classes together so inter-class optimisation can happen.
        names.sort()
        paths, jar = [], None
        for name, hashes in names:
            for hash in hashes:
                if (jar is None
                    or count >= MAX_CLASSES_PER_JAR
                    or bytes >= MAX_BYTES_PER_JAR):
                    
                    if jar is not None:
                        jar.close()
                    path = os.path.join(dir, "%s.%d.jar" % (
                        os.path.basename(self.path), len(paths) + 1))
                    paths.append(path)
                    jar = zipfile.ZipFile(path, "w", zipfile.ZIP_STORED)
                    count = bytes = 0
                jar.writestr(
                    zipfile.ZipInfo("%s.class" % name), self.classes[hash])
                count += 1
                bytes += len(self.classes[hash])
        if jar is not None:
            jar.close()
        return paths
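
    # The chunks are named after the job, so a job for foo.jar yields
    # foo.jar.1.jar, foo.jar.2.jar, ... alongside it in dir.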

    def compile(self, dir, libdir):
        """Generate the shared library and class mapping for one jarfile.
        If the shared library already exists then it will not be
        overwritten.  This is to allow optimizer failures and the like to
        be worked around."""
        soname = os.path.join(dir, self.basename + ".so")
        dbname = soname[:soname.rfind(".")] + ".db"
        inst_soname = os.path.join(libdir, os.path.basename(soname))
        cleanup = []
        # prepare: write the chunked jars and create an empty class
        # mapping database ("64" is its initial size)
        sources = self.writeSources(dir)
        cleanup.extend(sources)
        system([PATHS["dbtool"], "-n", dbname, "64"])
        # compile and link
        if os.path.exists(soname):
            warn("not recreating %s" % soname)
        else:
            if len(sources) == 1:
                system([PATHS["gcj"], "-shared"] +
                       GCJFLAGS + LDFLAGS +
                       [sources[0], "-o", soname])
            else:
                objects = []
                for source in sources:
                    obj = os.path.join(dir, os.path.basename(source) + ".o")
                    system([PATHS["gcj"], "-c"] +
                           GCJFLAGS +
                           [source, "-o", obj])
                    objects.append(obj)
                    cleanup.append(obj)
                system([PATHS["gcj"], "-shared"] +
                       GCJFLAGS + LDFLAGS +
                       objects + ["-o", soname])
        # dbtool
        for source in sources:
            system([PATHS["dbtool"], "-f", dbname, source, inst_soname])
        # clean up
        for item in cleanup:
            os.unlink(item)
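
    # For a job with basename foo.jar (illustrative), compile() runs
    # roughly:
    #   gcj-dbtool -n foo.jar.db 64
    #   gcj -shared -fPIC -findirect-dispatch -fjni foo.jar.1.jar -o foo.jar.so
    #   gcj-dbtool -f foo.jar.db foo.jar.1.jar /usr/lib/gcj/<package>/foo.jar.so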

class JarJob(Job):
    """A Job whose origin was a jarfile."""

    def __init__(self, path):
        Job.__init__(self, path)
        self._walk(zipfile.ZipFile(path, "r"))

    def _walk(self, zf):
        for name in zf.namelist():
            bytes = zf.read(name)
            if bytes.startswith(ZIPMAGIC):
                self._walk(zipfile.ZipFile(StringIO.StringIO(bytes)))
            elif bytes.startswith(CLASSMAGIC):
                self.addClass(bytes)
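
    # Archives nested inside other archives (a jar inside a war, say)
    # are walked recursively, so their classes join the enclosing job.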

class DirJob(Job):
    """A Job whose origin was a directory of classfiles."""

    def __init__(self, path):
        Job.__init__(self, path)
        # os.path.walk passes its third argument as _visit's first, so
        # the unbound method receives self correctly.
        os.path.walk(path, DirJob._visit, self)

    def _visit(self, dir, items):
        for item in items:
            path = os.path.join(dir, item)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            fp = open(path, "rb")
            magic = fp.read(4)
            if magic == CLASSMAGIC:
                self.addClass(magic + fp.read())
            fp.close()
    
def weed_jobs(jobs):
    """Remove any jarfiles that are completely contained within
    another.  This is more common than you'd think, and we only
    need one nativified copy of each class after all."""
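    # Control-flow note: the for/else/break ladder below restarts both
    # loops from the top whenever a job is removed, since removing an
    # element mid-iteration would otherwise skip entries.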
    jobs = copy.copy(jobs)
    while True:
        for job1 in jobs:
            for job2 in jobs:
                if job1 is job2:
                    continue
                if job1.isSupersetOf(job2):
                    msg = "subsetted %s" % job2.path
                    if job2.isSupersetOf(job1):
                        if (isinstance(job1, DirJob) and
                            isinstance(job2, JarJob)):
                            # In the braindead case where a package
                            # contains an expanded copy of a jarfile
                            # the jarfile takes precedence.
                            continue
                        msg += " (identical)"
                    warn(msg)
                    jobs.remove(job2)
                    break
            else:
                continue
            break
        else:
            break
        continue
    return jobs

def set_basenames(jobs):
    """Ensure that each jarfile has a different basename."""
    names = {}
    for job in jobs:
        name = os.path.basename(job.path)
        names.setdefault(name, []).append(job)
    for name, group in names.items():
        if len(group) == 1:
            group[0].basename = name
            continue
        # Prefix the jar filenames to make them unique: common trailing
        # path components are dropped and the first distinct one is
        # joined on, so (hypothetically) /x/a/foo.jar and /x/c/foo.jar
        # become a_foo.jar and c_foo.jar.
        # XXX will not work in most cases -- needs generalising
        group = [(job.path.split(os.sep), job) for job in group]
        minlen = min([len(bits) for bits, job in group])
        group = [(bits[-minlen:], job) for bits, job in group]
        bits = zip(*[bits for bits, job in group])
        while True:
            row = bits[-2]
            for bit in row[1:]:
                if bit != row[0]:
                    break
            else:
                del bits[-2]
                continue
            break
        group = zip(
            ["_".join(name) for name in zip(*bits[-2:])],
            [job for bits, job in group])
        for name, job in group:
            warn("building %s as %s" % (job.path, name))
            job.basename = name
    # XXX keep this check until we're properly general
    names = {}
    for job in jobs:
        name = job.basename
        if name in names:
            raise Error("%s: duplicate jobname" % name)
        names[name] = 1

def system(command):
    """Execute a command."""
    prefix = os.environ.get("PS4", "+ ")
    prefix = prefix[0] + prefix
    print >>sys.stderr, prefix + " ".join(command)

    status = os.spawnv(os.P_WAIT, command[0], command)
    if status > 0:
        raise Error("%s exited with code %d" % (command[0], status))
    elif status < 0:
        raise Error("%s killed by signal %d" % (command[0], -status))

def warn(msg):
    """Print a warning message."""
    print >>sys.stderr, "%s: warning: %s" % (
        os.path.basename(sys.argv[0]), msg)

# XXX everything above this should be in a library, aotcompile.py or
# something, everything from here to if __name__ == "__main__" in
# classlib.py, and the last little bit in /usr/bin/aot-compile-rpm.
# But all that's a little fiddly to do while aot-compile-rpm is
# managed by alternatives :-/

def classname(bytes):
    """Extract a classfile's name from its bytes.  The name is two
    constant pool hops away: klass.name indexes a CONSTANT_Class entry
    whose value indexes the CONSTANT_Utf8 entry holding the name in
    internal ("com/example/Foo") form."""
    klass = Class(bytes)
    return klass.constants[klass.constants[klass.name][1]][1]

class Class:
    """A cut-down version of Katana's class file parser, enough to
    extract the class's name but no more.  The unabridged version
    lives at http://inauspicious.org/files/scripts/classfile.py"""
    
    def __init__(self, arg):
        if hasattr(arg, "read"):
            self.fp = arg
        elif isinstance(arg, str):
            if arg.startswith(CLASSMAGIC):
                self.fp = StringIO.StringIO(arg)
            else:
                self.fp = open(arg, "rb")
        else:
            raise TypeError(type(arg))

        magic = self._read_int()            # 0xcafebabe, not validated here
        minor, major = self._read(">HH")    # classfile version, unused

        self._read_constants_pool()

        access_flags = self._read_short()
        self.name = self._read_reference_Class()

        del self.fp

    def _read_constants_pool(self):
        self.constants = {}
        skip = False
        for i in xrange(1, self._read_short()):  # pool indices run 1..count-1
            if skip:
                skip = False
                continue
            tag = {
                1: "Utf8", 3: "Integer", 4: "Float", 5: "Long",
                6: "Double", 7: "Class", 8: "String", 9: "Fieldref",
                10: "Methodref", 11: "InterfaceMethodref",
                12: "NameAndType"}[self._read_byte()]
            skip = tag in ("Long", "Double") # these occupy two pool slots
            self.constants[i] = (tag, getattr(self, "_read_constant_" + tag)())

    def _read_reference_Utf8(self):
        return self._read_references("Utf8")[0]

    def _read_reference_Class(self):
        return self._read_references("Class")[0]

    def _read_reference_Class_NameAndType(self):
        return self._read_references("Class", "NameAndType")

    def _read_references(self, *args):
        result = []
        for arg in args:
            index = self._read_short()
            result.append(index)
        return result

    def _read_constant_Utf8(self):
        return self.fp.read(self._read_short())

    def _read_constant_Integer(self):
        return self._read_int()

    def _read_constant_Float(self):
        return self._read(">f")[0]

    def _read_constant_Long(self):
        return self._read(">q")[0]

    def _read_constant_Double(self):
        return self._read(">d")[0]

    _read_constant_Class = _read_reference_Utf8
    _read_constant_String = _read_reference_Utf8
    _read_constant_Fieldref = _read_reference_Class_NameAndType
    _read_constant_Methodref = _read_reference_Class_NameAndType
    _read_constant_InterfaceMethodref = _read_reference_Class_NameAndType

    def _read_constant_NameAndType(self):
        return self._read_reference_Utf8(), self._read_reference_Utf8()

    def _read_int(self):
        # ">I" reads an unsigned big-endian 32-bit value on any platform.
        return self._read(">I")[0]

    def _read_short(self):
        return self._read(">H")[0]

    def _read_byte(self):
        return self._read("B")[0]

    def _read(self, fmt):
        return struct.unpack(fmt, self.fp.read(struct.calcsize(fmt)))

if __name__ == "__main__":
    try:
        name = os.environ.get("RPM_PACKAGE_NAME")
        if name is None:
            raise Error, "this script is designed for use in rpm specfiles"
        arch = os.environ.get("RPM_ARCH")
        if arch == "noarch":
            raise Error, "cannot be used on noarch packages"
        buildroot = os.environ.get("RPM_BUILD_ROOT")
        if buildroot in (None, "/"):
            raise Error, "bad $RPM_BUILD_ROOT"
        opt_flags = re.sub(r"-O[2-9]+", "-O", os.environ.get("RPM_OPT_FLAGS", ""))
        GCJFLAGS = opt_flags.split() + GCJFLAGS

        # XXX: This script should not accept options: having them
        # means it cannot be integrated into rpm.  But gcj cannot
        # yet build each and every jarfile, so we must be able to
        # exclude until it can.
        # XXX --exclude is also used in the jonas rpm to stop
        # everything being made a subset of the mammoth client
        # jarfile. Should adjust the subset checker's bias to
        # favour many small jarfiles over one big one.
        try:
            options, exclusions = sys.argv[1:], []
            while options:
                if options.pop(0) != "--exclude":
                    raise ValueError
                exclusions.append(os.path.join(
                    buildroot, options.pop(0).lstrip(os.sep)))
        except (ValueError, IndexError):
            print >>sys.stderr, "usage: %s [--exclude PATH]..." % (
                os.path.basename(sys.argv[0]))
            sys.exit(1)
        
        aot_compile(buildroot, os.path.join(PATHS["libdir"], name), exclusions)
    except Error, e:
        print >>sys.stderr, "%s: error: %s" % (
            os.path.basename(sys.argv[0]), e)
        sys.exit(1)
