#! /usr/bin/env python

import os, sys

def ReplaceNaN(x,default):
    """Return ``default`` if ``x`` does not compare equal to itself, else ``x``.

    A floating-point NaN is the typical value for which ``x == x`` is false,
    so this substitutes ``default`` for NaN and passes everything else through.
    """
    return x if x == x else default

class SherpaHistogram:
    """A histogram read from a Sherpa analysis ``.dat`` file.

    The file is parsed on construction.  If a line with more than ten
    whitespace-separated fields is encountered, the file is treated as a 2D
    histogram and all reading/writing is delegated to ``self.histo2D``.

    Attributes:
        path, filename: location of the input file (joined with os.path.join).
        data: list of dicts with keys 'LowEdge', 'Content' and 'Error'.
        binwidth: bin width computed from the header line, or None if no
            header line was seen.
        is2D: True once the file has been recognised as a 2D histogram.
        histo2D: SherpaHistogram2D helper for the same file.
    """

    def __init__(self, path, filename):
        self.path = path
        self.filename = filename
        self.data=[]
        self.is2D = False
        # Filled in by read_input() when a header line is found.
        self.binwidth = None
        self.histo2D = SherpaHistogram2D(path, filename)
        self.read_input()

    def read_input(self):
        """Parse the input file into ``self.data`` and ``self.binwidth``.

        Lines are classified by their number of whitespace-separated fields:
        more than 10 switches to 2D handling, 7 to 10 is the header from
        which the bin width is derived, 2 is (low edge, content), and 3 to 5
        is (low edge, content, error, ...) with extra fields ignored.
        Any other line is reported on stdout and skipped.
        """
        # 'with' guarantees the file is closed even if a float()/int()
        # conversion below raises.
        with open(os.path.join(self.path, self.filename), "r") as f:
            for line in f:
                fields = line.split()
                nfields = len(fields)
                if nfields > 10:
                    self.is2D = True
                    self.histo2D.read_input()
                    break
                if nfields > 6:
                    self.binwidth = (float(fields[3])-float(fields[2]))/(int(fields[1])-2)
                elif nfields == 2:
                    self.data.append({'LowEdge': float(fields[0]),
                                      'Content': float(fields[1]),
                                      'Error':   0.0})
                elif 3 <= nfields <= 5:
                    # Only the first three columns are used.
                    self.data.append({'LowEdge': float(fields[0]),
                                      'Content': float(fields[1]),
                                      'Error':   float(fields[2])})
                else:
                    print("Can't handle 1D histo with line of length %s:\n %s" % (nfields, fields))
        # Guard against files without any data rows: pop() on an empty list
        # would raise IndexError.
        if not self.is2D and self.data:
            self.data.pop() # removes overflow bin

    def write_datapoint(self, f, xval, xerr, yval, yerr):
        """Write one point as 'xval xerr- xerr+ yval yerr- yerr+' (symmetric errors).

        Reads the module-level ``opts.REPLACE_NANS`` flag; when set, NaN
        values are replaced with 0 before writing.
        """
        if opts.REPLACE_NANS:
            xval=ReplaceNaN(xval,0)
            xerr=ReplaceNaN(xerr,0)
            yval=ReplaceNaN(yval,0)
            yerr=ReplaceNaN(yerr,0)
        f.write('%e\t %e\t %e\t %e\t %e\t %e\n' % (xval,xerr,xerr,yval,yerr,yerr))

    def write_datapointset_header(self, f, ananame):
        """Write the opening lines of a YODA Scatter2D block to ``f``.

        The first component of ``self.path`` is dropped, the remaining
        components are joined with '_' and appended directly to ``ananame``
        (or to 'Sherpa' when ``ananame`` is None) to form the YODA path.
        """
        path = "/".join(self.path.split("/")[1:])
        prefix = ananame if ananame is not None else "Sherpa"
        path = "/"+prefix+path.replace("/","_")
        name = self.filename.replace("/","_")
        name = name.replace("+","plus")
        name = name.replace(".dat", "")
        title = self.filename.replace("_", " ")
        title = title.replace(".dat", "")
        f.write('# BEGIN YODA_SCATTER2D %s\n' % (path+"/"+name))
        f.write('Type=Scatter2D\n')
        f.write('Path=%s\n' % (path+"/"+name))
        f.write('Title=%s\n' % title)
        f.write('# xval   xerr-   xerr+   yval    yerr-   yerr+\n')

    def write_datapointset_footer(self, f):
        """Write the closing line of a YODA Scatter2D block to ``f``."""
        f.write('# END YODA_SCATTER2D\n\n')

    def write_datapointset(self, f, ananame):
        """Write the whole histogram to ``f`` as a YODA scatter block.

        Delegates to ``self.histo2D`` for 2D input.  For 1D input each bin
        becomes a point at the bin centre with half the bin width as x error.

        Raises:
            ValueError: if data rows were read but no header line provided
                the bin width.
        """
        if self.is2D:
            self.histo2D.write_datapointset(f,ananame)
            return
        if self.data and self.binwidth is None:
            # Fail before anything is written so the output is not left
            # with a half-finished block.
            raise ValueError("No histogram header line found in %s"
                             % os.path.join(self.path, self.filename))
        self.write_datapointset_header(f, ananame)
        # Loop-invariant; only evaluated when there is something to write.
        half_width = 0.5*float(self.binwidth) if self.data else 0.0
        for bindata in self.data:
            xval = bindata['LowEdge'] + half_width
            xerr = half_width
            yval = bindata['Content']
            yerr = bindata['Error']
            self.write_datapoint(f, xval, xerr, yval, yerr)
        self.write_datapointset_footer(f)


class SherpaHistogram2D:
    """A 2D histogram read from a Sherpa analysis ``.dat`` file.

    Construction only records the file location; call read_input() to parse
    the file before writing.

    Attributes:
        path, filename: location of the input file (joined with os.path.join).
        data: list of dicts with keys 'LowEdgeX', 'LowEdgeY', 'Content'
            and 'Error'.
        xbinwidth, ybinwidth: bin widths computed from the header line, or
            None if no header line has been read.
    """

    def __init__(self, path, filename):
        self.path = path
        self.filename = filename
        self.data=[]
        # Filled in by read_input() when a header line is found.
        self.xbinwidth = None
        self.ybinwidth = None

    def read_input(self):
        """Parse the input file into ``self.data`` and the two bin widths.

        Lines are classified by their number of whitespace-separated fields:
        more than 10 is the header (fields 2-4 give the x binning, fields
        5-7 the y binning, each as count/low/high), 3 is (x, y, content) and
        4 is (x, y, content, error).  Any other line is reported on stdout
        and skipped.
        """
        # 'with' guarantees the file is closed even if a float()/int()
        # conversion below raises.
        with open(os.path.join(self.path, self.filename), "r") as f:
            for line in f:
                fields = line.split()
                nfields = len(fields)
                if nfields > 10:
                    self.xbinwidth=(float(fields[4])-float(fields[3]))/(int(fields[2]))
                    self.ybinwidth=(float(fields[7])-float(fields[6]))/(int(fields[5]))
                elif nfields == 3:
                    self.data.append({'LowEdgeX': float(fields[0]),
                                      'LowEdgeY': float(fields[1]),
                                      'Content': float(fields[2]),
                                      'Error':   0.0})
                elif nfields == 4:
                    self.data.append({'LowEdgeX': float(fields[0]),
                                      'LowEdgeY': float(fields[1]),
                                      'Content': float(fields[2]),
                                      'Error':   float(fields[3])})
                else:
                    print("Can't handle 2D histo with line of length %s:\n %s" % (nfields, fields))

    def write_datapoint(self, f, xval, xerr, yval, yerr, zval, zerr):
        """Write one point with symmetric errors on x, y and z.

        Reads the module-level ``opts.REPLACE_NANS`` flag; when set, NaN
        values are replaced with 0 before writing.
        """
        if opts.REPLACE_NANS:
            xval=ReplaceNaN(xval,0)
            xerr=ReplaceNaN(xerr,0)
            yval=ReplaceNaN(yval,0)
            yerr=ReplaceNaN(yerr,0)
            zval=ReplaceNaN(zval,0)
            zerr=ReplaceNaN(zerr,0)
        f.write('%e\t %e\t %e\t %e\t %e\t %e\t %e\t %e\t %e\n' % (xval,xerr,xerr,yval,yerr,yerr,zval,zerr,zerr))

    def write_datapointset_header(self, f, ananame):
        """Write the opening lines of a YODA Scatter3D block to ``f``.

        The first component of ``self.path`` is dropped, the remaining
        components are joined with '_' and appended directly to ``ananame``
        (or to 'Sherpa' when ``ananame`` is None) to form the YODA path.
        """
        path = "/".join(self.path.split("/")[1:])
        prefix = ananame if ananame is not None else "Sherpa"
        path = "/"+prefix+path.replace("/","_")
        name = self.filename.replace("/","_")
        name = name.replace("+","plus")
        name = name.replace(".dat", "")
        title = self.filename.replace("_", " ")
        title = title.replace(".dat", "")
        f.write('# BEGIN YODA_SCATTER3D %s\n' % (path+"/"+name))
        f.write('Type=Scatter3D\n')
        f.write('Path=%s\n' % (path+"/"+name))
        f.write('Title=%s\n' % title)
        f.write('# xval   xerr-   xerr+   yval    yerr-   yerr+   zval    zerr-   zerr+\n')

    def write_datapointset_footer(self, f):
        """Write the closing line of a YODA Scatter3D block to ``f``."""
        f.write('# END YODA_SCATTER3D\n\n')

    def write_datapointset(self, f, ananame):
        """Write the whole histogram to ``f`` as a YODA Scatter3D block.

        Each bin becomes a point at the bin centre, with half the bin width
        as the x and y errors and the bin content/error as z.

        Raises:
            ValueError: if data rows are present but no header line provided
                the bin widths.
        """
        if self.data and (self.xbinwidth is None or self.ybinwidth is None):
            # Fail before anything is written so the output is not left
            # with a half-finished block.
            raise ValueError("No histogram header line found in %s"
                             % os.path.join(self.path, self.filename))
        self.write_datapointset_header(f, ananame)
        # Loop-invariant; only evaluated when there is something to write.
        half_x = 0.5*float(self.xbinwidth) if self.data else 0.0
        half_y = 0.5*float(self.ybinwidth) if self.data else 0.0
        for bindata in self.data:
            xval = bindata['LowEdgeX'] + half_x
            xerr = half_x
            yval = bindata['LowEdgeY'] + half_y
            yerr = half_y
            zval = bindata['Content']
            zerr = bindata['Error']
            self.write_datapoint(f, xval, xerr, yval, yerr, zval, zerr)
        self.write_datapointset_footer(f)



from optparse import OptionParser

# Command-line interface.  ``opts`` must stay a module-level name: the
# histogram classes' write_datapoint() methods read opts.REPLACE_NANS.
parser = OptionParser(usage="""%prog anadir1 [anadir2 ...]


Creates a YODA file for each input directory, containing all histograms
found in .dat files  (of the Sherpa internal analysis format) in the
given directory.
These YODA files can then be plotted using Rivet's plotting tools,
e.g. rivet-mkhtml anadir1.yoda [anadir2.yoda ...].
""")
parser.add_option("-a", "--ananame", dest="ANA_NAME", default=None, metavar="NAME",
                  help="give an optional path name for the analysis")
parser.add_option("-n", "--no-jX-dirs", dest="NO_JX_DIRS", action="store_true", default=False, metavar="NO_JX_DIRS",
                  help="omit jX subdirectories")
parser.add_option("-r", "--replace-nans", dest="REPLACE_NANS", action="store_true", default=False, metavar="REPLACE_NANS",
                  help="replace nan's with zeros")

opts, args = parser.parse_args()

if not args:
    sys.stderr.write("Must specify at least one analysis directory\n")
    sys.exit(1)

# Convert each analysis directory into one '<directory>.yoda' file.
for arg in args:
    if not os.path.isdir(arg):
        # Report instead of dropping the argument silently.
        sys.stderr.write("Skipping %s: not a directory\n" % arg)
        continue
    # Drop trailing slashes; fall back to the raw argument if it consists
    # only of slashes (indexing an empty string would raise IndexError).
    anadir = arg.rstrip("/") or arg
    if anadir[0:2]=="./":
        anadir = anadir[2:]
    print("Processing %s" % anadir)

    # 'with' ensures the output is flushed and closed for every directory.
    with open(anadir+'.yoda', 'w') as yoda_file:
        for path, dirs, files in os.walk(anadir):
            if opts.NO_JX_DIRS and len(dirs) > 0:
                # Build the list of directories to keep and assign it back
                # in place: os.walk only honours in-place changes, and
                # removing entries while iterating would skip neighbours.
                kept = []
                for subdir in dirs:
                    if (len(subdir)<=3 and subdir[:1]=="j") or ("j[" in subdir):
                        print("Skipping %s/%s" % (path, subdir))
                    else:
                        kept.append(subdir)
                dirs[:] = kept
            for datfile in files:
                if not datfile.endswith(".dat"):
                    continue
                print("Converting %s in %s" % (datfile, path))
                histo = SherpaHistogram(path, datfile)
                histo.write_datapointset(yoda_file, opts.ANA_NAME)
