#!/usr/bin/env python
'''
Submit jobs on cluster via PBS qsub command
Gao Wang (copyleft) 2012

ChangeLog
---------
2014-May-20
	* Add "--batch" option
2014-May-15
	* Add "-H" option
2013-March-17
	* Allow for additional pbs attributes via "qsub -W" switch
2012-April-02
	* Initial implementation
'''
import glob
import os
import random
import re
import shutil
import sys
from subprocess import PIPE, Popen
from time import sleep
try:
    from argparse import ArgumentParser
except:
    sys.exit('argparse module is required (available in Python 2.7.2+ or Python 3.2.1+)')

class qsub:
    '''Build PBS job scripts from shell commands and pipe them to ``qsub``.

    The command list given to :meth:`submit` is cut into chunks of ``size``
    commands; every chunk becomes one job script that is written to the
    standard input of the submission command (``qsub``, or ``echo`` for a
    dry run).
    '''
    def __init__(self, size, hold_size, name, nodes, ppn, mem, walltime, queue, directory, emails,
                 additional_attributes, parallel = False, pbs = None, dry_run = False):
        '''
        size: number of commands per job; a false value means "all commands in one job"
        hold_size: when set, each job is written as a job array ("#PBS -t 1-N%hold_size")
        name: job name, also used as prefix of the .out / .err files
        nodes, ppn, mem, walltime: values written to the "#PBS -l" resource lines
        queue: list of queue names, one is picked at random per job (may be None)
        directory: directory the job changes to before running; defaults to the
            current working directory
        emails: list of notification addresses; invalid ones are dropped with a warning
        additional_attributes: list of "attr=value" strings passed via "qsub -W"
        parallel: wrap the commands in a GNU parallel here-document
        pbs: path prefix under which a copy of every job script is saved (None: do not save)
        dry_run: pipe scripts to "echo" instead of "qsub"

        Exits the program if the walltime does not look like hh:mm:ss.
        '''
        self.size = size
        self.hold_size = hold_size
        self.name = name
        self.nodes = nodes
        self.ppn = ppn
        self.mem = mem
        self.walltime = walltime
        self.queue = queue
        self.dir = directory if directory else os.getcwd()
        self.parallel = parallel
        self.pbs = pbs
        if parallel:
            try:
                # communicate() drains both pipes and reaps the child, so the
                # probe does not leave a zombie process or open descriptors.
                Popen(['parallel', '--help'], shell = False, stdout = PIPE, stderr = PIPE).communicate()
            except OSError:
                sys.stderr.write("Parallel mode disabled since 'GNU parallel' is not installed.\n")
                self.parallel = False
        if not re.match("(.+?):(.+?):(.+?)", self.walltime):
            sys.exit('Invalid walltime {0}'.format(self.walltime))
        # The dot before the top-level domain must be escaped, otherwise it
        # matches any character and "user@examplecom" is accepted.
        valid_email = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9._%-]+\.[a-zA-Z]{2,}$")
        self.emails = []
        for email in (emails or []):
            if valid_email.match(email):
                self.emails.append(email)
            else:
                sys.stderr.write('Invalid email address ignored: {0}\n'.format(email))
        if dry_run:
            self.exe = ['echo']
        elif additional_attributes:
            # "-W" takes ONE comma-separated value; extra argv items would be
            # read by qsub as the script file name.
            self.exe = ['qsub', '-W', ','.join(additional_attributes)]
        else:
            self.exe = ['qsub']

    def _build_script(self, chunk, jext):
        '''Return the text of the PBS script that runs the commands in ``chunk``.

        ``jext`` is the suffix appended to the job name (empty string or "_<n>").
        '''
        jobname = self.name + jext
        parts = ['#PBS -S /bin/bash\n#PBS -V\n#PBS -N %s\n' % jobname]
        if self.emails:
            # PBS expects a comma-separated list without blanks
            parts.append('#PBS -M %s\n#PBS -m abe\n' % ','.join(self.emails))
        if self.queue:
            parts.append('#PBS -q {0}\n'.format(random.choice(self.queue)))
        if self.hold_size:
            parts.append('#PBS -t 1-{0}%{1}\n'.format(len(chunk), self.hold_size))
        parts.append('#PBS -l nodes=%s:ppn=%s,walltime=%s\n' % (self.nodes, self.ppn, self.walltime))
        if self.mem:
            parts.append('#PBS -l mem=%s\n' % self.mem)
        parts.append('#PBS -o {0}.out\n#PBS -e {0}.err\n'.format(jobname))
        # Di Zhang suggests using cd instead of #PBS -d
        parts.append('cd %s\n' % self.dir)
        if self.parallel:
            if self.nodes == 1:
                parts.append('parallel -j %s << EOF\n' % self.ppn)
            else:
                parts.append('parallel -j %s -u --sshloginfile $PBS_NODEFILE --wd $PWD << EOF\n' % self.ppn)
        if self.hold_size:
            # array mode: every array member runs only the command whose
            # 1-based position equals its $PBS_ARRAYID
            body = ['if [ $PBS_ARRAYID -eq {0} ]; then {1}; fi'.format(idx, item)
                    for idx, item in enumerate(chunk, 1)]
        else:
            body = chunk
        parts.append('\n'.join(body))
        parts.append('\nEOF\nexit 0\n' if self.parallel else '\nexit 0\n')
        return ''.join(parts)

    def submit(self, cmds):
        '''Submit ``cmds`` (a list of shell command strings) in jobs of ``self.size`` commands.

        Prints whatever the submission command writes to stdout (the job id
        for qsub), forwards its stderr to our stderr, and, if ``self.pbs`` is
        set, saves every generated script as "<pbs><suffix>.pbs".

        Raises ValueError if the job size is negative.
        '''
        if not self.size:
            self.size = len(cmds)
        if not cmds:
            return
        if self.size < 1:
            # a negative step would never advance past the end of the list
            raise ValueError('Invalid job size {0}'.format(self.size))
        py3 = sys.version_info[0] > 2
        encoding = sys.getdefaultencoding()
        dry = self.exe == ['echo']
        single_job = self.size == len(cmds)
        for j, start in enumerate(range(0, len(cmds), self.size), 1):
            jext = '' if single_job else '_%s' % j
            script = self._build_script(cmds[start:start + self.size], jext)
            payload = script.encode(encoding) if py3 else script
            # submit job
            p = Popen(self.exe, shell = False, stdin = PIPE, stdout = PIPE, stderr = PIPE, close_fds = True)
            (cout, cerr) = p.communicate(payload)
            if py3:
                cout = cout.decode(encoding, 'replace')
                cerr = cerr.decode(encoding, 'replace')
            # print job name
            output = cout.rstrip()
            if output:
                print(output)
            # do not hide submission errors from the user
            errors = cerr.rstrip()
            if errors:
                sys.stderr.write(errors + '\n')
            # keep backup scripts
            if self.pbs:
                with open(self.pbs + jext + '.pbs', 'wb') as f:
                    f.write(payload)
            if dry:
                continue
            # be gentle with the scheduler: short pause after every job and
            # a long one after every 250 jobs
            sleep(0.1)
            if j % 250 == 0:
                sys.stderr.write('250 jobs submitted; taking a 10s break.\n')
                sleep(10)


class BatchSubmitter:
    '''Submit saved "*.pbs" scripts a few at a time, keeping the queue below a limit.

    Scripts found in ``dirname`` are submitted with ``qsub`` and then moved
    to ``dirname/submitted``. Three sibling files are maintained next to
    that directory: "submitted.limit" (the job limit, re-read before every
    round so it can be edited while running), "submitted.qstat" (the last
    two counts obtained from qstat) and "submitted.log" (progress).
    '''
    def __init__(self, dirname):
        '''Collect the scripts in ``dirname`` and create an empty "submitted" directory.'''
        self.files = glob.glob(os.path.join(dirname, '*.pbs'))
        self.dir = os.path.join(dirname, 'submitted')
        # Equivalent of "rm -rf DIR; mkdir DIR" without going through a shell,
        # so paths containing blanks or shell metacharacters are safe.
        if os.path.isdir(self.dir) and not os.path.islink(self.dir):
            shutil.rmtree(self.dir)
        elif os.path.lexists(self.dir):
            os.remove(self.dir)
        os.mkdir(self.dir)
        self.total = len(self.files)
        self.submitted = 0

    def submit(self, limit, interval):
        '''Submit all collected scripts, at most ``limit`` queued/running at a time.

        After every round the progress log is updated; while scripts remain,
        the method sleeps ``interval`` seconds before checking the queue again.
        Exits the program if ``qsub`` cannot be executed.
        '''
        self.limit = limit
        with open(self.dir + '.limit', 'w') as f:
            f.write('%s\n' % limit)
        self._write_log()
        # Loop on the remaining work itself: this ends as soon as the last
        # script is submitted instead of waiting for another free slot.
        while self.files:
            for _ in range(self.__getslots()):
                if not self.files:
                    break
                script = self.files.pop(0)
                try:
                    status = Popen(['qsub', script], shell = False).wait()
                except OSError as e:
                    sys.exit('Cannot execute qsub: {0}'.format(e))
                if status != 0:
                    sys.stderr.write('qsub exited with status {0} for {1}\n'.format(status, script))
                # the script is moved regardless of the qsub exit status
                shutil.move(script, self.dir)
                self.submitted += 1
                sleep(0.5)
            self._write_log()
            if self.files:
                sleep(interval)

    def _write_log(self):
        '''Record "<submitted>/<total> submitted" in the progress log.'''
        with open(self.dir + '.log', 'w') as f:
            f.write('%s/%s submitted\n' % (self.submitted, self.total))

    @staticmethod
    def _count_lines(pipeline):
        '''Run a shell pipeline that ends in "wc -l" and return the number it prints.'''
        out = Popen(pipeline, shell = True, stdout = PIPE).communicate()[0]
        return int(out.split()[0])

    def __getslots(self):
        '''Return how many more jobs may be submitted right now (never negative).'''
        # the limit file may have been edited since the last round
        try:
            with open(self.dir + '.limit', 'r') as f:
                total = int(f.readline().split()[0])
        except (IOError, OSError, IndexError, ValueError):
            total = self.limit
        try:
            # counts[0]: qstat lines whose 5th column matches Q, R, "-" or
            #            "Time" (presumably queued/running jobs plus the two
            #            header lines -- confirm against the local qstat layout)
            # counts[1]: all qstat lines
            counts = [self._count_lines(r"qstat | awk '{print $5}' | grep -P 'Q|R|\-|Time' | wc -l"),
                      self._count_lines("qstat | wc -l")]
            with open(self.dir + '.qstat', 'w') as f:
                f.write('%d\n%d\n' % (counts[0], counts[1]))
            if counts[0] < 2 and counts[1] > 0:
                # something is wrong with qstat
                counts[0] = total + 2
        except (IOError, OSError, IndexError, ValueError):
            counts = [total + 2, 0]
        # When qstat prints nothing at all there are no header lines to
        # discount; clamp so an empty queue does not yield limit + 2 slots.
        active = max(counts[0] - 2, 0)
        return max(total - active, 0)


def read_commands(lines, split_semicolon = False):
    '''Turn raw input lines into a list of shell commands.

    Lines starting with "#" and blank lines are dropped, a trailing
    backslash joins a line with the next one, and surrounding whitespace is
    stripped. With ``split_semicolon`` every ";" additionally starts a new
    command.
    '''
    continuation = re.compile(r'\\(\s*)\n')
    separator = re.compile(r'\r\n|(\s*);(\s*)')
    pieces = []
    for line in lines:
        if line.startswith('#') or not line.strip():
            continue
        if split_semicolon:
            line = separator.sub('\n', line)
        else:
            line = line.replace('\r\n', '\n')
        pieces.append(continuation.sub(' ', line))
    return [item.strip() for item in ''.join(pieces).split('\n') if item.strip()]


def reset_directory(path):
    '''Remove ``path`` if it exists and create it again as an empty directory.'''
    # same effect as "rm -rf PATH; mkdir PATH" but without a shell, so names
    # with blanks or shell metacharacters are handled safely
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)
    os.mkdir(path)


if __name__ == '__main__':
    parser = ArgumentParser(description='A handy program to send commands to the Portable Batch System (PBS) in cluster environments. Commands should be piped to this program from standard input stream.',epilog = '''2012 (copyleft) Gao Wang''' )
    parser.add_argument('name', help = 'job name')
    parser.add_argument('-s', '--size', metavar = 'N', type = int, help = 'number of commands per job, default to submitting all commands as one job')
    parser.add_argument('-H', '--hold-size', metavar = 'N', type = int, dest = 'hold_size',
                        help = 'triggers "array mode" per job and defines number of commands per hold')
    parser.add_argument('-n', '--nodes', metavar = 'N', type = int, default = 1, help = 'number of nodes reserved per job, default to 1')
    parser.add_argument('-p', '--ppn', metavar = 'N', type = int, default = 1, help = 'number of cpus reserved per node, default to 1')
    parser.add_argument('-m', '--mem', help = 'reserved memory with units be b (bytes), w (words), kb, kw, mb, mw, gb or gw')
    parser.add_argument('-w', '--walltime', default = '240:00:00', help = 'reserved computation time (hh:mm:ss), default to 240:00:00')
    parser.add_argument('-q', '--queue', nargs = '+', help = 'name of the queue to send jobs to, can be multiple queues to randomly chose from')
    parser.add_argument('-d', '--dir', help = 'directory under which the jobs will be executed, default to current directory')
    parser.add_argument('-e', '--emails', nargs = '+', default = [], help = 'email notification addresses')
    parser.add_argument('-a', '--additional_attributes', metavar = 'OPT', nargs = '+', default = [], help = 'additional qsub attributes, e.g., -a "group_list=bigmem"')
    # the option name is historical: what gets split on is the semicolon
    parser.add_argument('--remove-colon', dest = 'rmcolon', action = 'store_true', help = 'replace semicolon with line break')
    parser.add_argument('--parallel', action = 'store_true',
                        help = 'use parallel mode, powered by GNU parallel, and will overwrite "--hold-size" option when triggered')
    parser.add_argument('--save', action = 'store_true', help = 'save pbs scripts to disk')
    parser.add_argument('--dry-run', action = 'store_true', dest = 'dry_run',
                        help = 'perform a trial run which saves pbs scripts to disk without submitting any job')
    parser.add_argument('--batch', metavar = 'N', nargs = 2, type = int, default = [0, 0],
                        help = 'submit N1 jobs at a time; check every N2 seconds to submit additional jobs if any of the previous N1 jobs are done')
    args = parser.parse_args()
    if args.size is not None and args.size < 0:
        parser.error('--size must not be negative')
    # read commands from standard input
    try:
        cmds = read_commands(sys.stdin.readlines(), args.rmcolon)
    except KeyboardInterrupt:
        sys.exit("Standard input stream expected")
    # various adjustments
    if len(cmds) == 0:
        sys.exit(0)
    if not args.size:
        args.size = len(cmds)
    if args.size >= len(cmds):
        # everything fits in one job: nothing to submit in batches
        args.batch[0] = 0
    batch_mode = args.batch[0] > 0 and args.batch[1] > 0
    if batch_mode:
        # batch mode first writes all scripts to disk, then submits them gradually
        args.dry_run = True
    if args.dry_run:
        args.save = True
    if args.parallel:
        args.hold_size = None
    # validate before anything on disk is touched
    if args.hold_size and args.size <= args.hold_size:
        sys.exit("Array hold blocks must be smaller than total of commands in the array!")
    if args.dir:
        args.dir = os.path.abspath(os.path.expanduser(args.dir))
        if not os.path.isdir(args.dir):
            sys.exit("Cannot find directory {}".format(args.dir))
        os.chdir(args.dir)
    fn = dirname = None
    if args.save:
        if args.size < len(cmds):
            # several scripts: collect them in a fresh "<name>_PBS" directory
            dirname = args.name + '_PBS'
            try:
                reset_directory(dirname)
            except OSError as e:
                sys.exit("Cannot create directory {0}: {1}".format(dirname, e))
            fn = os.path.join(dirname, args.name)
        else:
            fn = args.name
    # submit job
    q = qsub(args.size, args.hold_size, args.name, args.nodes, args.ppn, args.mem, args.walltime, args.queue,
             args.dir, args.emails, args.additional_attributes, args.parallel, fn, args.dry_run)
    try:
        q.submit(cmds)
    except Exception as e:
        sys.exit("ERROR: %s" % (e))
    # batch submit job
    if batch_mode:
        bs = BatchSubmitter(dirname)
        bs.submit(args.batch[0], args.batch[1])