#! /usr/bin/env python

from __future__ import print_function
import logging, os, sys
import argparse

## Base paths etc. for set and index file downloading
## CVMFSBASE is an absolute POSIX path, so files found under it are copied
## locally, while URLBASE is opened as a URL. Both are the default --source
## locations; the requested file name is appended to them directly, hence the
## trailing slashes.
CVMFSBASE = "/cvmfs/sft.cern.ch/lcg/external/lhapdfsets/current/"
URLBASE = "http://lhapdfsets.web.cern.ch/lhapdfsets/current/"
## Name of the PDF set index file, both at the sources and in the local list directory
INDEX_FILENAME = "pdfsets.index"


class SetInfo(object):
    """Stores PDF metadata: name, version, ID code.

    Two SetInfo objects compare equal when their names match, and a SetInfo
    also compares equal to any other object that equals its name (typically a
    plain string). Hashing follows the name as well, so instances can be stored
    in sets and used as dictionary keys.
    """

    def __init__(self, name, id_code, version):
        self.name = name
        self.id_code = id_code
        self.version = version

    def __eq__(self, other):
        if isinstance(other, SetInfo):
            return self.name == other.name
        return self.name == other

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone leaves the class unhashable on Python 3. Hash
        # on the name so that objects which compare equal also hash alike.
        return hash(self.name)

    def __repr__(self):
        return self.name


def get_reference_list(filepath):
    """Reads reference file and returns list of SetInfo objects.

    The reference file is space-delimited, with columns:
    id_code name [version]

    Two-column rows (index files from LHAPDF <= 6.0.5) carry no version, which
    is then recorded as None; three-column rows (>= 6.1.0) include it.

    If the file cannot be opened an error is logged. If any row has the wrong
    number of columns or a non-integer ID/version, the file is reported as
    corrupted and an empty list is returned.
    """
    import csv

    database = []
    reader = None
    try:
        # The context manager closes the file on every exit path
        with open(filepath, "r") as csv_file:
            logging.debug("Reading %s" % filepath)
            reader = csv.reader(csv_file, delimiter=" ", skipinitialspace=True, strict=True)
            for row in reader:
                # <= 6.0.5
                if len(row) == 2:
                    id_code, name, version = int(row[0]), str(row[1]), None
                # >= 6.1.0
                elif len(row) == 3:
                    id_code, name, version = int(row[0]), str(row[1]), int(row[2])
                else:
                    raise ValueError
                database.append(SetInfo(name, id_code, version))
    except IOError:
        logging.error("Could not open %s" % filepath)
    except (ValueError, csv.Error):
        # The reader may not exist yet if the failure happened before parsing began
        line_num = reader.line_num if reader is not None else 0
        logging.error("Corrupted file on line %d: %s" % (line_num, filepath))
        database = []
    return database


def get_installed_list(_=None):
    """Build SetInfo records for every PDF set that the lhapdf module reports.

    The single positional argument is accepted for call compatibility and is
    not used. Each record holds the set name together with the ID code and data
    version read from the corresponding lhapdf PDF set object.
    """
    import lhapdf
    located = ((set_name, lhapdf.getPDFSet(set_name)) for set_name in lhapdf.availablePDFSets())
    return [SetInfo(set_name, found.lhapdfID, found.dataversion) for set_name, found in located]


# TODO: Move this into the Python module to allow Python-scripted downloading?
def download_url(source, dest_dir, dryrun=False):
    """Download a file from a URL or POSIX path source to the destination directory.

    A source starting with "/" or "file://" is treated as a local file and
    copied; anything else is opened as a URL. The destination directory is
    created if it does not exist, and the file is stored there under the base
    name of the source.

    With dryrun set, only the file name (plus its size where known) is logged
    and nothing is transferred.

    Returns True if the file was copied or downloaded, and False on a dry run
    or on failure.
    """
    import shutil

    if not os.path.isdir(os.path.abspath(dest_dir)):
        logging.info("Creating directory %s" % dest_dir)
        os.makedirs(dest_dir)
    dest_filepath = os.path.join(dest_dir, os.path.basename(source))

    # Decide whether to copy or download
    if source.startswith("/") or source.startswith("file://"):  # POSIX
        if source.startswith("file://"):
            source = source[len("file://"):]
        logging.debug("Downloading from %s" % source)
        logging.debug("Downloading to %s" % dest_filepath)
        try:
            file_size = os.stat(source).st_size
            if dryrun:
                # Only format a size when there is one to report
                if file_size:
                    logging.info("%s [%s]" % (os.path.basename(source), convertBytes(file_size)))
                else:
                    logging.info("%s" % os.path.basename(source))
                return False
            shutil.copy(source, dest_filepath)
        except (IOError, OSError, shutil.Error):
            # Missing or unreadable source: let the caller try the next one.
            # Ctrl-C is deliberately not caught here.
            logging.debug("Unable to download %s" % source)
            return False
        return True

    # URL
    url = source
    try:
        import urllib.request as urllib
    except ImportError:
        import urllib2 as urllib
    try:
        u = urllib.urlopen(url)
    except (urllib.URLError, ValueError):
        # ValueError covers sources with no recognisable URL scheme, e.g. a relative path
        logging.debug("Unable to download %s" % url)
        return False

    try:
        # A missing or malformed Content-Length just means no percentage is shown
        try:
            content_length = u.info().get("Content-Length", 0)
            if isinstance(content_length, list):
                file_size = int(content_length[0]) if content_length else 0
            else:
                file_size = int(content_length)
        except (TypeError, ValueError):
            file_size = 0

        logging.debug("Downloading from %s" % url)
        logging.debug("Downloading to %s" % dest_filepath)
        if dryrun:
            if file_size:
                logging.info("%s [%s]" % (os.path.basename(url), convertBytes(file_size)))
            else:
                logging.info("%s" % os.path.basename(url))
            return False

        try:
            dest_file = open(dest_filepath, "wb")
        except IOError:
            logging.error("Could not write to %s" % dest_filepath)
            return False
        try:
            try:
                file_size_dl = 0
                buffer_size = 8192
                while True:
                    chunk = u.read(buffer_size)
                    if not chunk:
                        break

                    file_size_dl += len(chunk)
                    dest_file.write(chunk)

                    # Carriage return so that the progress line redraws in place
                    status = chr(13) + "%s: " % os.path.basename(url)
                    status += r"%s" % convertBytes(file_size_dl).rjust(10)
                    if file_size:
                        status += r"[%3.1f%%]" % (file_size_dl * 100. / file_size)
                    sys.stdout.write(status + " ")
                    sys.stdout.flush()
            except urllib.URLError as e:
                logging.error("Error during download: %s", getattr(e, "reason", e))
                return False
            except KeyboardInterrupt:
                logging.error("Download halted by user")
                return False
        finally:
            dest_file.close()
            print("")
    finally:
        u.close()

    return True


def extract_tarball(tar_filename, dest_dir, keep_tarball):
    """Extracts a tarball to the destination directory.

    The gzipped tarball is looked for inside dest_dir and unpacked there.
    Failures are logged rather than raised. Unless keep_tarball is true, the
    tarball is removed afterwards, whether or not extraction succeeded.
    """
    import tarfile

    tarpath = os.path.join(dest_dir, tar_filename)
    try:
        tar_file = tarfile.open(tarpath, "r:gz")
        try:
            # The archive has been downloaded, so where the running Python
            # supports extraction filters use the "data" filter, which refuses
            # members that would land outside dest_dir.
            if hasattr(tarfile, "data_filter"):
                tar_file.extractall(dest_dir, filter="data")
            else:
                tar_file.extractall(dest_dir)
        finally:
            tar_file.close()
    except Exception:
        # Best effort: tar, gzip and filesystem errors all end up here, while
        # KeyboardInterrupt and SystemExit still propagate.
        logging.error("Unable to extract %s to %s" % (tar_filename, dest_dir))
    if not keep_tarball:
        try:
            os.remove(tarpath)
        except OSError:
            logging.error("Unable to remove %s after expansion" % tar_filename)


def convertBytes(size, nDecimalPoints=1):
    """Format a byte count as a human-readable string such as "1.5 KB".

    The value is scaled by powers of 1024 to the largest unit of B, KB, MB or
    GB that keeps it at 1 or above, and rounded to nDecimalPoints decimal
    places. Sizes beyond the GB range are still reported in GB. Zero or
    negative sizes give "0 B".
    """
    units = ("B", "KB", "MB", "GB")
    if size <= 0:
        return "0 B"

    # Repeated division by 1024 is exact in floating point, which avoids the
    # rounding problems of taking a logarithm, and the bound on i keeps the
    # unit lookup in range for very large sizes.
    value = float(size)
    i = 0
    while value >= 1024 and i < len(units) - 1:
        value /= 1024
        i += 1

    s = round(value, nDecimalPoints)
    if s > 0:
        return "%s %s" % (s, units[i])
    return "0 B"


def download_file(sources, filename, dest_dir, dryrun=False):
    """Fetch filename into dest_dir from the first source that delivers it.

    Each source has filename appended to it and is passed to download_url in
    order. Returns True as soon as one attempt succeeds; otherwise logs the
    full list of locations tried and returns False.
    """
    failed = []
    for location in (base + filename for base in sources):
        if not download_url(location, dest_dir, dryrun):
            failed.append(location)
            continue
        return True
    logging.error("Unable to download from any of %s" % failed)
    return False


def globfilt(pdfs, patterns):
    """Unix-style pattern matching of arguments

    Returns a list of the entries of pdfs that match at least one of the
    patterns, in their original order. An entry is listed once however many
    patterns it matches. If no patterns are given, every entry is returned.
    """
    if not patterns:
        return list(pdfs)
    from fnmatch import fnmatch
    return [pdf for pdf in pdfs if any(fnmatch(pdf, pattern) for pattern in patterns)]




class CommandHandler(object):
    """\
A program for managing LHAPDF parton distribution function data files.

The main sub-commands that can be used are:
  - list|ls:     list available PDF sets, optionally filtered and/or categorised by status
  - show:        show metadata details of specified PDF sets
  - update:      download and install a new PDF set index file
  - install|get: download and install new PDF set data files
  - upgrade:     download and install newer replacement PDF set data files where available
"""

    ## NB. the class docstring is the text shown by the top-level --help, so
    ## implementation notes live in comments rather than in the docstring.

    ## Methods that may be invoked by name from the command line. Dispatching on
    ## an explicit list stops private helpers and attributes being "run".
    _COMMANDS = ("list", "ls", "show", "update", "install", "get", "upgrade")

    def __init__(self):
        """Parse the command line and run the requested sub-command."""

        ## Load settings from Python module
        try:
            import lhapdf
            paths = lhapdf.paths()
            DATADIR = paths[0] if paths else None
            VERSION = lhapdf.__version__
        except ImportError:
            DATADIR = None
            VERSION = None

        ## Parse the command line
        default_sources = [CVMFSBASE, URLBASE]
        ap = argparse.ArgumentParser(description=self.__doc__, formatter_class=argparse.RawTextHelpFormatter)
        ap.add_argument("COMMAND", metavar="COMMAND [suboptions]", help="Subcommand to run")
        ap.add_argument("--listdir", default=DATADIR, dest="LISTDIR",
                        help="Directory containing the lhapdf.index list file [default: %(default)s]")
        ap.add_argument("--pdfdir", default=DATADIR, dest="PDFDIR",
                        help="Directory for installation of PDF set data [default: %(default)s]")
        ap.add_argument("--source", default=default_sources, action="append", dest="SOURCES",
                        help="Prepend a path or URL to be used as a source of data files [default: %(default)s]")
        ap.add_argument("-q", "--quiet", help="Suppress normal messages", dest="VERBOSITY", action="store_const",
                        const=logging.ERROR, default=logging.INFO)
        ap.add_argument("-v", "--verbose", help="Output debug messages", dest="VERBOSITY", action="store_const",
                        const=logging.DEBUG, default=logging.INFO)
        if VERSION:
            ap.add_argument("--version", action="version", version=VERSION)
        self.mainargs, otherargs = ap.parse_known_args()

        ## Apply verbosity settings
        logging.basicConfig(format="%(message)s", level=self.mainargs.VERBOSITY)

        ## Re-order the sources list since argparse doesn't have a "prepend" action!
        ## The "append" action puts user-supplied sources after the defaults, so
        ## rotate by the number of defaults to bring the user's sources to the front.
        ndefaults = len(default_sources)
        self.mainargs.SOURCES = self.mainargs.SOURCES[ndefaults:] + self.mainargs.SOURCES[:ndefaults]

        ## Check for a command
        if self.mainargs.COMMAND not in self._COMMANDS:
            print("Unrecognized command")
            ap.print_help()
            sys.exit(2)

        ## Use dispatch pattern to invoke method with same name:
        getattr(self, self.mainargs.COMMAND)(otherargs)


    def _scanpdfs(self):
        """Build the reference and installed PDF set tables, keyed by set name.

        Fills self.master_list from the index file and self.installed from the
        lhapdf module, then flags every reference entry with `installed` and
        `outdated` attributes. Both tables are left empty if either the list
        or the PDF directory is unknown.
        """
        self.master_list, self.installed = {}, {}

        ## Return empty lists if relevant search directories are not known
        if self.mainargs.LISTDIR is None or self.mainargs.PDFDIR is None:
            return

        ## List and install commands require us to build lists of reference and installed PDFs
        indexpath = os.path.join(self.mainargs.LISTDIR, INDEX_FILENAME)
        logging.debug("Index file = " + indexpath)
        for pdf in get_reference_list(indexpath):
            self.master_list[pdf.name] = pdf
        for pdf in get_installed_list(self.mainargs.PDFDIR):
            self.installed[pdf.name] = pdf

        ## Check installation status of all PDFs
        for name, reference in self.master_list.items():
            local = self.installed.get(name)
            reference.installed = local is not None
            ## Without a version on both sides there is nothing to compare
            if local is None or local.version is None or reference.version is None:
                reference.outdated = False
            else:
                reference.outdated = local.version < reference.version


    def _matching_sets(self, patterns):
        """Return the reference set names whose name or ID code matches any pattern."""
        return [name for name, info in self.master_list.items()
                if globfilt([name, str(info.id_code)], patterns)]


    def _describe(self, pdf):
        """Return the multi-line metadata summary that `show` prints for one set."""
        import lhapdf
        ps = lhapdf.getPDFSet(pdf)
        lines = [
            ps.name,
            "=" * len(pdf),
            "LHAPDF ID: {:d}".format(ps.lhapdfID),
            "Version: {:d}".format(ps.dataversion),
            ps.description,
            "Number of members: {:d}".format(ps.size),
            "Error type: {:s}".format(ps.errorType),
        ]
        return "\n".join(lines) + "\n"


    def list(self, otherargs):
        """List all standard PDF sets, or search using a Unix-style pattern.
        (by default lists all sets available for download; use --installed or --outdated to explore those installed on the current system)"""
        ap = argparse.ArgumentParser(description=self.list.__doc__, usage="%(prog)s list [options] [pattern...]")
        ap.add_argument("PATTERNS", nargs="*", help="patterns to match PDF sets against")
        ag = ap.add_mutually_exclusive_group()
        ag.add_argument("--installed", dest="INSTALLED", action="store_true", help="list installed PDF sets")
        ag.add_argument("--outdated", dest="OUTDATED", action="store_true",
                        help="list installed, but outdated, PDF sets")
        ap.add_argument("--codes", dest="CODES", action="store_true", help="additionally show ID codes")
        subargs = ap.parse_args(otherargs)

        ## Scan the current PDF collection
        self._scanpdfs()

        ## Filter PDFs on optional patterns and status
        pdfs = self._matching_sets(subargs.PATTERNS)
        if subargs.INSTALLED:
            pdfs = [pdf for pdf in pdfs if self.master_list[pdf].installed]
        if subargs.OUTDATED:
            pdfs = [pdf for pdf in pdfs if self.master_list[pdf].outdated]

        ## Display
        # TODO: (optional) ordering by LHAPDF ID code
        for pdf in sorted(pdfs):
            if subargs.CODES:
                print("%d  %s" % (self.master_list[pdf].id_code, pdf))
            else:
                print(pdf)
        sys.exit(0)


    def ls(self, otherargs):
        """Alias for the list command."""
        self.list(otherargs)


    def show(self, otherargs):
        """Show details for installed PDF sets matching Unix-style patterns."""
        ap = argparse.ArgumentParser(description=self.show.__doc__, usage="%(prog)s show [options] pattern...")
        ap.add_argument("PATTERNS", nargs="+", help="patterns to match PDF sets against")
        subargs = ap.parse_args(otherargs)

        ## Scan the current PDF collection
        self._scanpdfs()

        ## Filter PDFs on optional patterns and status
        ## NOTE(review): candidates come from the reference list, so a match that
        ## is not installed is still handed to lhapdf.getPDFSet -- confirm intended.
        pdfs = self._matching_sets(subargs.PATTERNS)

        ## Display
        strs = [self._describe(pdf) for pdf in sorted(pdfs)]
        print("\n\n".join(strs))
        sys.exit(0)


    def update(self, otherargs):
        """Update the list of available PDF sets."""
        ap = argparse.ArgumentParser(description=self.update.__doc__, usage="%(prog)s update")
        ## No sub-options are defined; parsing still rejects stray arguments and serves --help
        ap.parse_args(otherargs)
        if self.mainargs.LISTDIR is not None:
            download_file(self.mainargs.SOURCES, INDEX_FILENAME, self.mainargs.LISTDIR)
        else:
            print("PDF index file location not known: can't update")
        sys.exit(0)


    def install(self, otherargs):
        """Download and unpack a list of PDFs, or those matching a Unix-style pattern."""
        ap = argparse.ArgumentParser(description=self.install.__doc__, usage="%(prog)s install [options] pattern...")
        ap.add_argument("PATTERNS", nargs="*", help="patterns to match PDF sets against")
        ap.add_argument("--dryrun", dest="DRYRUN", action="store_true", help="Do not download sets")
        ap.add_argument("--upgrade", dest="UPGRADE", action="store_true", help="Force reinstall (used to upgrade)")
        ap.add_argument("--keep", dest="KEEP_TARBALLS", action="store_true", help="Keep the downloaded tarballs")
        subargs = ap.parse_args(otherargs)

        ## Nothing can be installed without somewhere to put it
        if self.mainargs.PDFDIR is None:
            print("PDF data file location not known: can't upgrade or install")
            return

        ## Scan the current PDF collection
        self._scanpdfs()

        ## Filter PDFs on optional patterns, handling each matching set once only
        pdfs = sorted(set(globfilt(self.master_list.keys(), subargs.PATTERNS)))

        if not pdfs:
            logging.warning("No PDFs known matching patterns: %s" % ", ".join(subargs.PATTERNS))

        for pdf in pdfs:
            if pdf in self.installed and not subargs.UPGRADE:
                logging.warning("PDF already installed: %s (use --upgrade to force install)" % pdf)
                continue

            if self.master_list[pdf].version == -1:
                logging.warning("PDF %s is unvalidated. You need to download this manually" % pdf)

            tar_filename = pdf + ".tar.gz"
            if download_file(self.mainargs.SOURCES, tar_filename, self.mainargs.PDFDIR, dryrun=subargs.DRYRUN):
                extract_tarball(tar_filename, self.mainargs.PDFDIR, subargs.KEEP_TARBALLS)


    def get(self, otherargs):
        """Alias for the install command."""
        self.install(otherargs)


    def upgrade(self, otherargs):
        """Reinstall all PDF sets considered outdated by the local reference list"""
        ap = argparse.ArgumentParser(description=self.upgrade.__doc__, usage="%(prog)s upgrade [options] [pattern...]")
        ap.add_argument("PATTERNS", nargs="*", help="patterns to match PDF sets against")
        ap.add_argument("--dryrun", dest="DRYRUN", action="store_true", help="Do not download sets")
        ap.add_argument("--keep", dest="KEEP_TARBALLS", action="store_true", help="Keep the downloaded tarballs")
        subargs = ap.parse_args(otherargs)

        ## Nothing can be upgraded without somewhere to put it
        if self.mainargs.PDFDIR is None:
            print("PDF data file location not known: can't upgrade or install")
            return

        ## Scan the current PDF collection
        self._scanpdfs()

        ## Get the PDFs in need of an update
        outdated_pdfs = [pdf for pdf in self.master_list.keys() if self.master_list[pdf].outdated]

        ## Filter PDFs on optional patterns, handling each matching set once only
        upgrade_pdfs = sorted(set(globfilt(outdated_pdfs, subargs.PATTERNS)))

        for pdf in upgrade_pdfs:
            tar_filename = pdf + ".tar.gz"
            if download_file(self.mainargs.SOURCES, tar_filename, self.mainargs.PDFDIR, dryrun=subargs.DRYRUN):
                extract_tarball(tar_filename, self.mainargs.PDFDIR, subargs.KEEP_TARBALLS)


## Script entry point: constructing the handler parses the command line and
## dispatches to the requested sub-command.
if __name__ == "__main__":
    CommandHandler()