#!

from imp import find_module,load_module
from sys import path as sys_path
from os import path as os_path
from os import makedirs
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter
from subprocess import check_output

sys_path.append('')

from ufo_interface import s_lorentz, s_color, write_model, write_run_card
from ufo_interface.templates import sconstruct_template
from ufo_interface.message import error, warning, progress

def parse_args():
      """Build the command line interface and parse the process arguments.

      The only positional argument is the path to the UFO model
      directory. All options carry defaults, which are shown in the
      help output through ArgumentDefaultsHelpFormatter.

      Returns the argparse namespace with the attributes 'ufo_path',
      'ncore', 'modelflags', 'lorentzflags', 'includedir', 'libdir',
      'installdir' and 'noopt'.
      """
      arg_parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
      arg_parser.add_argument("ufo_path",
                              help = "Path to input UFO model directory")
      # 'type=int' is required: without it a value given on the command
      # line stays a string and cannot be compared to an integer later on
      arg_parser.add_argument("--ncore",
                              type=int,
                              default=1,
                              help = "Number of cores used for compilation")
      arg_parser.add_argument("--modelflags",
                              default = "-g -O0 -fno-var-tracking --std=c++11",
                              help = 'Compiler flags for model source code')
      arg_parser.add_argument("--lorentzflags",
                              default = "-g -O2 -ffast-math --std=c++11",
                              help = 'Compiler flags for lorentz source code')
      arg_parser.add_argument("--includedir",
                              default = '/mt/home/dreichelt/batch/workspace/LH23-VBF/branch-223/sherpa/../install/include/SHERPA-MC',
                              help = 'Path to Sherpa headers')
      arg_parser.add_argument("--libdir",
                              default = '/mt/home/dreichelt/batch/workspace/LH23-VBF/branch-223/sherpa/../install/lib/SHERPA-MC',
                              help = 'Path to Sherpa libraries')
      arg_parser.add_argument("--installdir",
                              default = '/mt/home/dreichelt/batch/workspace/LH23-VBF/branch-223/sherpa/../install/lib/SHERPA-MC',
                              help = 'Installation path for new shared Model library')
      arg_parser.add_argument("--noopt",
                              action="store_true",
                              help = "Disable optimization of Lorentz calculators")

      return arg_parser.parse_args()

def _identity_partner(color, first_index):
      """Return the second argument of the first 'Identity(first_index,X)'
      term in 'color' as a string, or None if there is no such term.

      The whole argument up to the closing parenthesis is read, so
      indices with more than one digit are handled. A ValueError is
      raised if the argument is not a plain non-negative integer.
      """
      prefix = "Identity({0},".format(first_index)
      start = color.find(prefix)
      if start < 0:
            return None
      start += len(prefix)
      end = color.find(")", start)
      partner = color[start:end] if end >= 0 else ""
      if not partner.isdigit():
            raise ValueError("Cannot parse second index of '{0}' in color structure '{1}'".format(prefix, color))
      return partner

def check_color_struct(color, particles):
      """If the color structure contains "Identity", the assignment of
      fundamental/antifundamental/adjoint representation indices is
      often wrong. This function fixes this and returns a corrected
      color structure.

      'color' is the color structure string, 'particles' the sequence
      of particles of the vertex; only their 'color' attribute is read
      (-3 marks an antifundamental, 8 an octet index). The input string
      is not modified, a new string is returned. Applying the function
      to its own result leaves the result unchanged.

      Raises ValueError if an 'Identity(i,' term is found whose second
      argument is not a plain integer.
      """
      new_color = color

      # Collect all antifundamental indices in UFO numbering
      # scheme starting at 1
      af_inds = [i+1 for i,p in enumerate(particles) if p.color==-3]
      for a_ind in af_inds:
            f_ind = _identity_partner(new_color, a_ind)
            if f_ind is None:
                  continue
            # Now we have an antifundamental index as the first
            # argument, i.e. 'Identity(a_ind, f_ind)', which is
            # wrong. Need to swap a_ind and f_ind
            new_color = new_color.replace("Identity({0},{1})".format(a_ind,f_ind),
                                          "Identity({0},{1})".format(f_ind,a_ind))

      # Collect all octet indices in UFO numbering scheme starting at
      # 1. Replace Identities with dedicated Identities for octets.
      # The indices are processed from the last to the first one.
      oc_inds = [i+1 for i,p in enumerate(particles) if p.color==8]
      for ind0 in reversed(oc_inds):
            ind1 = _identity_partner(new_color, ind0)
            if ind1 is None:
                  continue
            new_color = new_color.replace("Identity({0},{1})".format(ind0,ind1),
                                          "IdentityG({0},{1})".format(ind0,ind1))

      return new_color

def _find_order_index(orders, name):
      """Return the index of the first coupling order in 'orders' whose
      'name' attribute equals 'name', or None if there is none."""
      for idx, order in enumerate(orders):
            if order.name == name:
                  return idx
      return None

def sort_orders(orders):
      """Sort the coupling orders such that the first item is the QCD order
      and the second item is the QED order.

      The list is reordered in place by swapping items and the same
      list object is returned. The relative order of the remaining
      items is not preserved.

      Raises ValueError if either the QCD or the QED order is not
      contained in the list. If only QED is missing, the QCD swap has
      already been applied when the exception is raised.
      """
      for position, name in enumerate(('QCD', 'QED')):
            idx = _find_order_index(orders, name)
            if idx is None:
                  raise ValueError('No {0} coupling found in model'.format(name))
            # Swap the item at the target position with the coupling
            # found. For QED the search cannot return position 0,
            # since that slot already holds the QCD order.
            orders[position], orders[idx] = orders[idx], orders[position]

      return orders

def check_model(model_name, model_path):
      """Import the UFO model 'model_name' located in the directory
      'model_path', sanitize it and return the imported module.

      The following steps are performed:
        - reject model names that start with a digit or that clash
          with a built-in model name
        - import the model, which verifies that the path is usable
        - correct the color structures of all vertices via
          check_color_struct
        - set 'non_ct_couplings' on the model (couplings of the
          vertices that are not listed in 'all_CTvertices', or all
          couplings if the model has no such attribute)
        - restrict 'all_lorentz' to structures used by a vertex in
          'all_vertices'
        - reorder 'all_orders' via sort_orders

      On an invalid name or a failing import an error message is
      printed and the process exits with status 1.
      """

      # Model name cannot begin with a number since class names in C++ cannot.
      # Slicing instead of indexing keeps an empty name from raising here.
      if model_name[:1].isdigit():
            error("Model name \"{0}\" starts with a digit. Please rename your UFO model directory".format(model_name))
            exit(1)

      # Check for conflicts with built-in models
      if model_name in ["SM", "HEFT", "TauPi"]:
            error("Model name \"{0}\" conflicts with built-in model. Please rename your UFO model directory.".format(model_name))
            exit(1)

      # Try to import the model to check if UFO path is ok
      module_file = None
      try:
            module_file, pathn, desc = find_module(model_name, [model_path])
            model = load_module(model_name, module_file, pathn, desc)
      except ImportError:
            # Report the path from the function arguments rather than
            # relying on a global variable set by the calling script
            error("Could not import UFO model from input path \"{0}\", make sure this is a path to a valid UFO model".format(os_path.join(model_path, model_name)))
            exit(1)
      finally:
            # find_module may hand back an open file, which the caller
            # has to close
            if module_file is not None:
                  module_file.close()

      for v in model.all_vertices:
            for i, old_color in enumerate(v.color):
                  new_color = check_color_struct(old_color, v.particles)
                  # Safety check: applying check_color_struct twice
                  # should not alter the result
                  assert(check_color_struct(new_color, v.particles)==new_color)
                  v.color[i] = new_color

      # For NLO models: need to filter out CT_Couplings
      model.non_ct_couplings = model.all_couplings
      if hasattr(model, "all_CTvertices"):
            non_ct_vertices  = [vtx for vtx in model.all_vertices if vtx not in model.all_CTvertices]
            # Flatten explicitly: on Python 3 the dictionary views
            # cannot be concatenated to a list with sum()
            model.non_ct_couplings = [cpl for vtx in non_ct_vertices for cpl in vtx.couplings.values()]

      # For NLO models: filter out lorentz structures
      # that ONLY appear in CT vertices as they are sometimes ill-formatted
      needed_names = {l.name for v in model.all_vertices for l in v.lorentz}
      model.all_lorentz = [l for l in model.all_lorentz if l.name in needed_names]

      # Sort coupling orders to ensure first order is QCD, second
      # order is QED
      model.all_orders = sort_orders(model.all_orders)

      return model

def make_output_dir(path):
      """Make sure that 'path' is an existing directory.

      Nothing is done if the directory is already there. If 'path'
      exists but is not a directory, an error message is printed and
      the process exits with status 1. Otherwise the directory is
      created, including missing parent directories.
      """
      # Already a directory: nothing left to do
      if os_path.isdir(path):
            return

      # Something else occupies the name, so it cannot be created
      if os_path.exists(path):
            error("Could not write to output path \"{0}\", file with the same name existing".format(path))
            exit(1)

      makedirs(path)

if __name__ == "__main__":
      # Script entry point: generate the source code for a UFO model,
      # compile and install it with scons and write an example run card

      # Extract command line args
      args = parse_args()

      # Check the path to the UFO models
      arg_path = os_path.abspath(args.ufo_path)
      model_path, model_name = os_path.split((arg_path.rstrip('/')))
      model = check_model(model_name, model_path)

      # Output paths: everything is generated in a '.sherpa'
      # subdirectory of the UFO model directory
      out_dir = os_path.join(arg_path, '.sherpa')
      sconstruct_file_path = os_path.join(out_dir, 'SConstruct')
      model_out_path = os_path.join(out_dir, 'Model.C')
      make_output_dir(out_dir)

      # Need this list in order to write model source code
      lorentzes = [s_lorentz(l) for l in model.all_lorentz]

      # Set of all color structures, no duplicates. The trivial
      # structure '1' does not get a calculator
      colors = set(sum([v.color for v in model.all_vertices],[]))
      colors = [col for col in colors if col!='1']

      # Write model source code
      progress("Generating model source code")
      write_model(model, lorentzes, model_name, model_out_path)

      # Write color calculator source code
      for col in colors:
            scol = s_color(col)
            progress("Generating source code for color calculator '{0}'".format(scol.name()))
            # Create an s_color instance temporarily only in this
            # scope, so that its memory intensive color tensor
            # instance gets deleted as soon as it has been written to
            # file
            scol.write(out_dir)

      # Write lorentz calculator source code
      optimize = not args.noopt
      for lor in lorentzes:
            if not lor.has_ghosts():
                  progress("Generating source code for lorentz calculator '{0}'".format(lor.name()))
                  lor.write(out_dir,optimize)

      # Generate and write SConstruct file
      lib_name = 'Sherpa{0}'.format(model_name)
      with open(sconstruct_file_path, 'w') as sconstruct_file:
            sconstruct_file.write(sconstruct_template.substitute(libname = lib_name,
                                                                 includedir = args.includedir,
                                                                 libdir = args.libdir,
                                                                 installdir = args.installdir,
                                                                 modelflags = args.modelflags,
                                                                 lorentzflags = args.lorentzflags))

      # compile and install
      progress("Compiling sources using scons")
      sconsargs = ['scons','-C',out_dir,'install']
      # Convert explicitly: the option value is a string if the parser
      # does not declare an integer type for it
      ncore = int(args.ncore)
      if ncore>1:
            sconsargs.append('-j{0}'.format(ncore))
      # Request text output, otherwise Python 3 returns bytes
      progress(check_output(sconsargs, universal_newlines=True))

      # write example run card to working dir, prepending underscores
      # until the file name does not clash with an existing file
      run_card_path="Sherpa.{0}_Example.yaml".format(model_name)
      while(os_path.exists(run_card_path)):
            run_card_path="_"+run_card_path
      progress("Writing example Sherpa config file '{0}' to working directory".format(run_card_path))
      write_run_card(model, model_name, run_card_path)

      progress("Finished generating model '{0}'\nPlease cite Eur.Phys.J. C, 75 3 (2015) 137\nif you make use of Sherpa's BSM features".format(model_name))

      exit(0)
