#! /usr/bin/env python

"""\
Usage: %(prog)s <PDF1> [<PDF2> [...]]

Plot PDF and alphaS values for the named LHAPDF6 sets.

TODO:
 * Allow user specification of the various PDF/parton line colours and styles
 * Allow user specification of the plot types to be shown
 * Allow user specification of the discrete xs, Qs, and PIDs lists
"""

from __future__ import print_function
import sys, os

import argparse
## Command-line interface: the module docstring doubles as the usage message
ap = argparse.ArgumentParser(usage=__doc__)

## Positional arguments: one or more PDF names
ap.add_argument(
    "PNAMES",
    metavar="NAME",
    nargs="+",
    help="PDF members to include in the plots",
)

## Axis range limits
ap.add_argument(
    "--xfmin",
    dest="XFMIN",
    metavar="NUM",
    type=float,
    default=1e-2,
    help="minimum xf value [default: %(default)s]",
)
ap.add_argument(
    "--xmin",
    dest="XMIN",
    metavar="NUM",
    type=float,
    default=1e-10,
    help="minimum x value [default: %(default)s]",
)
ap.add_argument(
    "--qmax",
    dest="QMAX",
    metavar="NUM",
    type=float,
    default=1e4,
    help="maximum Q value in GeV [default: %(default)s]",
)

## Output file format
ap.add_argument(
    "-f", "--format",
    dest="FORMAT",
    metavar="F",
    default="pdf",
    help="plot file format, i.e. file extension pdf/png/... [default: %(default)s]",
)

## Comma-separated lists; these are kept as strings here and split later
ap.add_argument(
    "--qs",
    dest="QS",
    metavar="Q1,Q2,...",
    default="1,10,100,1000,10000",
    help="discrete Q values to use on plots vs. x [default: %(default)s]",
)
ap.add_argument(
    "--xs",
    dest="XS",
    metavar="X1,X2,...",
    default="1e-5,1e-3,1e-2,1e-1",
    help="discrete x values to use on plots vs. Q [default: %(default)s]",
)
ap.add_argument(
    "--pids",
    dest="PIDS",
    metavar="ID1,ID2,...",
    default="0,1,2,3,4,5,-1,-2,-3,-4,-5",
    help="PID values to use on PDF plots [default: %(default)s]",
)
ap.add_argument(
    "--plots",
    dest="PLOTS",
    metavar="PLOT1,PLOT2,...",
    default="alphas,xf_x/pid,xf_x/q,xf_q/pid,xf_q/x",
    help="plot types to show, default value lists all types [default: %(default)s]",
)
ap.add_argument(
    "--aliases",
    dest="PALIASES",
    metavar="PATT/REPL,PATT/REPL,...",
    default=None,
    help="pattern/replacement pairs for PDF names in the plot legend",
)

## Verbosity: -q and -v write to the same destination, which is 1 if neither is given
ap.add_argument(
    "-q", "--quiet",
    dest="VERBOSITY",
    action="store_const",
    const=0,
    default=1,
    help="suppress non-essential messages",
)
ap.add_argument(
    "-v", "--verbose",
    dest="VERBOSITY",
    action="store_const",
    const=2,
    default=1,
    help="output debug messages",
)

args = ap.parse_args()

## Plot type names are matched case-insensitively: store them as an upper-cased list
requested_plots = args.PLOTS.upper()
args.PLOTS = requested_plots.split(",")

## Defensive check: print the usage text and exit with an error if no PDF names were given
if not args.PNAMES:
    print(__doc__)
    sys.exit(1)

import matplotlib.pyplot as plt
## Global matplotlib text settings: serif font family, with text rendering via LaTeX enabled
plt.rcParams.update({
    "font.family": "serif",
    # "font.serif": "Computer Modern Roman",
    "text.usetex": True,
})

## Line styles and colours that the plotting code cycles through (indexed modulo their length)
STYLES = [
    "-",
    "--",
    "-.",
    (0, (5, 2, 1, 2, 1, 2)),
    ":",
]
COLORS = [
    "red",
    "blue",
    "darkgreen",
    "orange",
    "purple",
    "magenta",
    "gray",
    "cyan",
]

## LaTeX symbols for the parton IDs
PARTONS = {
    -5: r"$\bar{b}$",
    -4: r"$\bar{c}$",
    -3: r"$\bar{s}$",
    -2: r"$\bar{u}$",
    -1: r"$\bar{d}$",
    1: r"$d$",
    2: r"$u$",
    3: r"$s$",
    4: r"$c$",
    5: r"$b$",
    0: r"$g$",
}

## Build the legend alias table from the comma-separated PATT/REPL option value
PALIASES = {}
for alias_spec in (args.PALIASES.split(",") if args.PALIASES else []):
    fields = alias_spec.split("/")
    # An entry without a "/" maps the name to an empty string;
    # anything after a second "/" is ignored
    replacement = fields[1] if len(fields) > 1 else ""
    PALIASES[fields[0]] = replacement


def palias(pname):
    """Return pname with its first "/"-separated component replaced by its alias.

    The text before the first "/" is looked up in PALIASES; if it has no
    entry there it is kept unchanged. The first "/" and everything after it
    are passed through untouched.
    """
    head, slash, rest = pname.partition("/")
    return PALIASES.get(head, head) + slash + rest


def tex_str(a):
    """Return a copy of the string a with every "_" and "#" preceded by a backslash."""
    escaped = a
    for special in ("_", "#"):
        escaped = escaped.replace(special, "\\" + special)
    return escaped


def tex_float(f):
    """Format the number f to two significant figures as LaTeX math source.

    The value is first rendered with the ".2g" format. If that rendering uses
    exponent notation, it is rewritten as "<mantissa> \\times 10^{<exponent>}"
    with the exponent converted to a plain integer (so "e-05" becomes -5);
    otherwise the rendering is returned as is. The result carries no "$"
    delimiters.
    """
    mantissa, marker, exponent = "{0:.2g}".format(f).partition("e")
    if not marker:
        # No exponent part: the plain rendering is already usable
        return mantissa
    return r"{0} \times 10^{{{1}}}".format(mantissa, int(exponent))


## Get sampling points in x,Q
from math import log10
import numpy as np
# Log-spaced grids of 100 points each: x from XMIN up to 1, Q from 1 up to QMAX
xs = np.logspace(log10(args.XMIN), 0, num=100)
qs = np.logspace(0, log10(args.QMAX), num=100)

# Discrete values parsed from the comma-separated command-line strings
xs_few = list(map(float, args.XS.split(",")))
qs_few = list(map(float, args.QS.split(",")))
pids = list(map(int, args.PIDS.split(",")))

## Load PDFs for plotting, indexed by name
import lhapdf
# Pass the requested verbosity level (0, 1 or 2) on to LHAPDF
lhapdf.setVerbosity(args.VERBOSITY)

# Create one PDF object per name given on the command line, keyed by that name
pdfs = {}
for pdf_name in args.PNAMES:
    pdfs[pdf_name] = lhapdf.mkPDF(pdf_name)

# Blank line on stdout after the loading step
print()

## Make PDF xf vs. x & Q plots for each parton flavour, and a single alpha_s vs. Q plot.
## A single figure/axes pair is created here; each plot clears and redraws it.
fig = plt.figure()
ax = fig.add_subplot(111)

## alpha_s vs Q plot: one solid line per PDF set, coloured by the set's position in PNAMES
if "ALPHAS" in args.PLOTS:
    plt.cla()
    n_colors = len(COLORS)
    for set_idx, pdf_name in enumerate(args.PNAMES):
        pdf = pdfs[pdf_name]
        alphas_curve = [pdf.alphasQ(qval) for qval in qs]
        ax.plot(
            qs,
            alphas_curve,
            label=tex_str(palias(pdf_name)),
            color=COLORS[set_idx % n_colors],
            ls="-",
        )
    ax.set_xlabel("$Q$")
    ax.set_ylabel(r"$\alpha_s(Q)$")
    ax.set_ylim(bottom=0)
    ax.set_xscale("log")
    l = ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")
    fname = "alpha_s.{}".format(args.FORMAT)
    # The progress message is suppressed only in quiet mode (VERBOSITY 0)
    if args.VERBOSITY > 0:
        print("Writing plot file", fname)
    fig.savefig(fname)

## xf vs. x plots (per PID): one file per parton ID, one curve per (PDF set, Q) pair
if "XF_X/PID" in args.PLOTS:
    for pid in pids:
        plt.cla()
        ax.text(0.95, 0.5, "PID={:d}".format(pid), transform=ax.transAxes, ha="right", va="top")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$xf(x,Q)$")
        for set_idx, pdf_name in enumerate(args.PNAMES):
            # Colour identifies the PDF set; line style identifies the Q value
            set_color = COLORS[set_idx % len(COLORS)]
            set_label = tex_str(palias(pdf_name))
            pdf = pdfs[pdf_name]
            for q_idx, qval in enumerate(qs_few):
                curve = [pdf.xfxQ(pid, xval, qval) for xval in xs]
                plt.plot(
                    xs,
                    curve,
                    label="{}, $Q={}~\\mathrm{{GeV}}$".format(set_label, tex_float(qval)),
                    color=set_color,
                    ls=STYLES[q_idx % len(STYLES)],
                    lw=1.0,
                )
        ax.set_xscale("log")
        # Lower y limit comes from --xfmin; the y axis is logarithmic
        ax.set_ylim(bottom=args.XFMIN)
        ax.set_yscale("log")
        l = ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")
        fname = "pdf_pid{:d}_x.{}".format(pid, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.savefig(fname)

## xf vs. x plots (per Q): one file per discrete Q value, one curve per (parton ID, PDF set) pair
if "XF_X/Q" in args.PLOTS:
    for q in qs_few:
        plt.cla()
        ax.set_xlabel("$x$")
        ax.set_ylabel("$xf(x,Q={}~\\mathrm{{GeV}})$".format(tex_float(q)))
        for k, pid in enumerate(pids):
            # Colour by position in the PID list rather than by the PID value itself:
            # Python's % maps negative PIDs onto non-negative indices, so with the
            # 8-entry palette e.g. PID 3 and PID -5 would otherwise share a colour.
            color = COLORS[k % len(COLORS)]
            for i, pname in enumerate(args.PNAMES):
                # Line style identifies the PDF set
                style = STYLES[i % len(STYLES)]
                xf_vals = [pdfs[pname].xfxQ(pid, x, q) for x in xs]
                title = "{}, ID={:d}".format(tex_str(palias(pname)), pid)
                plt.plot(xs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        # Lower y limit comes from --xfmin; the y axis is logarithmic
        ax.set_ylim(bottom=args.XFMIN)
        ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")
        # Integer-valued Q keeps the historical "pdf_q<int>_x" name. A non-integer Q
        # is written with its full repr so that e.g. Q=1.2 and Q=1.8 no longer both
        # truncate to "pdf_q1_x" and overwrite each other.
        q_tag = "{:d}".format(int(q)) if float(q).is_integer() else repr(q)
        fname = "pdf_q{}_x.{}".format(q_tag, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.savefig(fname)

## xf vs. Q plots (per PID): one file per parton ID, one curve per (PDF set, x) pair
if "XF_Q/PID" in args.PLOTS:
    for pid in pids:
        plt.cla()
        ax.text(0.05, 0.7, "PID={:d}".format(pid), transform=ax.transAxes, ha="left", va="center")
        ax.set_xlabel("$Q$")
        ax.set_ylabel("$xf(x,Q)$")
        for set_idx, pdf_name in enumerate(args.PNAMES):
            # Colour identifies the PDF set; line style identifies the x value
            set_color = COLORS[set_idx % len(COLORS)]
            set_label = tex_str(palias(pdf_name))
            pdf = pdfs[pdf_name]
            for x_idx, xval in enumerate(xs_few):
                curve = [pdf.xfxQ(pid, xval, qval) for qval in qs]
                plt.plot(
                    qs,
                    curve,
                    label="{}, $x={}$".format(set_label, tex_float(xval)),
                    color=set_color,
                    ls=STYLES[x_idx % len(STYLES)],
                    lw=1.0,
                )
        ax.set_xscale("log")
        # Lower y limit comes from --xfmin; the y axis is logarithmic
        ax.set_ylim(bottom=args.XFMIN)
        ax.set_yscale("log")
        l = ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")
        fname = "pdf_pid{:d}_q.{}".format(pid, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.savefig(fname)

## xf vs. Q plots (per x): one file per discrete x value, one curve per (parton ID, PDF set) pair
if "XF_Q/X" in args.PLOTS:
    for x in xs_few:
        plt.cla()
        ax.set_xlabel("$Q$")
        ax.set_ylabel("$xf(x={},Q)$".format(tex_float(x)))
        for k, pid in enumerate(pids):
            # Colour by position in the PID list rather than by the PID value itself:
            # Python's % maps negative PIDs onto non-negative indices, so with the
            # 8-entry palette e.g. PID 3 and PID -5 would otherwise share a colour.
            color = COLORS[k % len(COLORS)]
            for i, pname in enumerate(args.PNAMES):
                # Line style identifies the PDF set
                style = STYLES[i % len(STYLES)]
                xf_vals = [pdfs[pname].xfxQ(pid, x, q) for q in qs]
                title = "{}, ID={:d}".format(tex_str(palias(pname)), pid)
                plt.plot(qs, xf_vals, label=title, color=color, ls=style, lw=1.0)
        ax.set_xscale("log")
        # Lower y limit comes from --xfmin; the y axis is logarithmic
        ax.set_ylim(bottom=args.XFMIN)
        ax.set_yscale("log")
        ax.legend(loc=0, ncol=2, frameon=False, fontsize="xx-small")
        # Keep the historical two-decimal tag only when it represents x exactly
        # (0.1 -> "0.10"). Otherwise use the full repr: with a fixed "0.2f" tag the
        # default x values 1e-5 and 1e-3 both became "0.00", so the second plot
        # silently overwrote the first.
        x_tag = "{:0.2f}".format(x)
        if float(x_tag) != x:
            x_tag = repr(x)
        fname = "pdf_x{}_q.{}".format(x_tag, args.FORMAT)
        if args.VERBOSITY > 0:
            print("Writing plot file", fname)
        fig.savefig(fname)

## All requested plots have been written: close the shared figure
plt.close(fig)
## Finish with a blank line on stdout
print()