#!/usr/bin/env python

# Usage: sage -tp N <options> <files>
#
# <options> may include
#      --long                include lines with the phrase 'long time'
#      --verbose             debugging output during the test
#      --optional            also test all #optional examples
#      --only-optional <tag1,...,tagn>    only run tests including one
#                                         of the #optional tags
#      --randorder[=seed]    randomize order of tests
#
# This runs doctests on <files> in parallel using N threads.  If N is
# zero or omitted, set N to the value of the environment variable
# SAGE_NUM_THREADS_PARALLEL, which is set in sage-env.

import os
import shutil
import sys
import time
import pickle
import signal
import thread
import tempfile
import subprocess
import multiprocessing
import socket
import stat
import re

def usage():
    print """Usage: sage -tp N <options> <files>

<options> may include
     --long                include lines with the phrase 'long time'
     --verbose             debugging output during the test
     --optional            also test all #optional examples
     --only-optional <tag1,...,tagn>    only run tests including one
                                        of the #optional tags
     --randorder[=seed]    randomize order of tests

This runs doctests on <files> in parallel using N threads.  If N is
zero or omitted, try to use a sensible default number of threads: if
the '-j' flag of the environment variable 'MAKE' or 'MAKEFLAGS' is set,
use that setting. Otherwise, use min(8, number of CPU cores)."""

if len(sys.argv) == 1:
    usage()
    exit(1)

SAGE_ROOT = os.path.realpath(os.environ['SAGE_ROOT'])
SAGE_SITE = os.path.realpath(os.path.join(os.environ['SAGE_LOCAL'],
                                          'lib', 'python', 'site-packages'))
BUILD_DIR = os.path.realpath(os.path.join(SAGE_ROOT, 'devel', 'sage', 'build'))

try:
    XML_RESULTS = os.environ['XML_RESULTS']
except KeyError:
    XML_RESULTS = None

numiteration = int(os.environ.get('SAGE_TEST_ITER', 1))
numglobaliteration = int(os.environ.get('SAGE_TEST_GLOBAL_ITER', 1))
print 'Global iterations: ' + str(numglobaliteration)
print 'File iterations: ' + str(numiteration)

# Exit status for the whole run.
err = 0

def strip_automount_prefix(filename):
    """
    Strip prefixes added on automounted filesystems in some cases,
    which make the absolute path appear hidden.

    The candidate replacement is the path separator followed by the
    first component of ``filename`` (or just the separator when
    ``filename`` has no such component).  The candidate is returned only
    if it refers to the very same file as ``filename``; otherwise
    ``filename`` is returned unchanged.

    Two paths are considered the same file when both device and inode
    agree (``os.path.samefile``); inode numbers alone are only unique
    within one filesystem.  If either path does not exist or cannot be
    examined, ``filename`` is returned unchanged instead of raising.

    AUTHOR:
        -- Kate Minola
    """
    sep = os.path.sep
    parts = filename.split(sep, 2)
    candidate = sep + parts[1] if len(parts) >= 2 else sep
    try:
        if os.path.samefile(filename, candidate):
            return candidate
    except OSError:
        # Missing or inaccessible path: nothing can be stripped safely.
        pass
    return filename

def abspath(x):
    """
    Return the absolute form of the path ``x``, passed through
    strip_automount_prefix() to adjust for automounted (NFS) filesystems.
    """
    absolute = os.path.abspath(x)
    return strip_automount_prefix(absolute)

def abs_sage_path(f):
    """
    Return the path of ``f`` relative to the Sage root or, failing that,
    relative to the current directory (the module-level ``CUR``).

    If the absolute path of ``f`` lies in neither directory, the
    absolute path itself is returned.  If it equals one of the two
    directories, the empty string is returned.

    A directory only counts as a prefix on a path-component boundary:
    with a Sage root of ``/opt/sage``, the file ``/opt/sage-old/x.py``
    is not treated as being inside the Sage root.
    """
    abs_path = abspath(f)
    for base in (SAGE_ROOT, CUR):
        if abs_path == base:
            return ''
        # Normalise to exactly one trailing separator so that the root
        # directory ('/') is handled like any other base directory.
        prefix = base.rstrip(os.path.sep) + os.path.sep
        if abs_path.startswith(prefix):
            return abs_path[len(prefix):]
    return abs_path

def test_cmd(f):
    """
    Build the ``sage -t`` command line shown to the user for the file
    ``f``: the current option string followed by the shortened path.
    """
    shown_path = abs_sage_path(f)
    return " ".join(["sage", "-t", opts, shown_path])

def skip(F):
    """
    Returns true if the file should not be tested
    """
    if not os.path.exists(F):
        # XXX IMHO this should never happen; in case it does, it's certainly
        #     an error to be reported (either filesystem, or bad name specified
        #     on the command line). -leif
        return True
    G = abspath(F)
    i = G.rfind(os.path.sep)
    # XXX The following should IMHO be performed in populatefilelist():
    #     (Currently, populatefilelist() only looks for "__nodoctest__".)
    if os.path.exists(os.path.join(G[:i], 'nodoctest.py')):
        printmutex.acquire()
        print "%s (skipping) -- nodoctest.py file in directory" % test_cmd(F)
        sys.stdout.flush()
        printmutex.release()
        return True
    filenm = os.path.split(F)[1]
    if (filenm[0] == '.' or (os.path.sep + '.' in G.lstrip(os.path.sep + '.'))
        or 'nodoctest' in open(G).read()[:50]):
        return True
    if G.find(os.path.join('doc', 'output')) != -1:
        return True
    # XXX The following is (also/already) handled in populatefilelist():
    if not (os.path.splitext(F)[1] in ['.py', '.pyx', '.spyx', '.tex', '.pxi', '.sage', '.rst']):
        return True
    return False

def test_file(F):
    """
    This is the function that actually tests a file.

    The ``sage-doctest`` script from ``SAGE_ROOT/local/bin`` is run on
    ``F`` up to ``numiteration`` times, from within the directory
    containing ``F``, stopping at the first non-zero exit status.

    Returns a tuple ``(F, ret, finished_time, ol)`` where ``ret`` is the
    exit status of the last run, ``finished_time`` its wall-clock
    duration in seconds and ``ol`` the text it wrote to standard output.
    If starting or waiting for the subprocess raises, the tuple is
    ``(F, 32, 0, ol)``.  If ``numiteration`` is not positive, nothing is
    run and ``(F, 0, 0, '')`` is returned.
    """
    doctest_opts = opts
    if SAGE_SITE in os.path.realpath(F) and '-force_lib' not in doctest_opts:
        doctest_opts += ' -force_lib'

    dirname, filestr = os.path.split(F)
    # Run the script directly with an argument list instead of a shell
    # command string, so that paths containing spaces, quotes or other
    # shell metacharacters are passed through unchanged.
    command = [os.path.join(SAGE_ROOT, 'local', 'bin', 'sage-doctest')]
    command.extend(doctest_opts.split())
    command.append(filestr)

    ret = 0
    finished_time = 0
    ol = ''
    with tempfile.NamedTemporaryFile() as outfile:
        for i in range(numiteration):
            os.chdir(dirname)
            # Each run starts with an empty output file; the child
            # shares our file offset, so rewind before and after.
            outfile.seek(0)
            outfile.truncate()
            try:
                t = time.time()
                ret = subprocess.call(command, stdout=outfile)
                finished_time = time.time() - t
            except (Exception, KeyboardInterrupt):
                # Deliberately broad: whatever went wrong (including an
                # interrupt delivered to this worker) is reported as a
                # result instead of killing the worker.
                outfile.seek(0)
                return (F, 32, 0, outfile.read())
            if ret < 0:
                # Killed by signal N: report 128 + N, as a shell would.
                ret = 128 - ret
            outfile.seek(0)
            ol = outfile.read()
            if ret != 0:
                break
    return (F, ret, finished_time, ol)

def process_result(result):
    """
    This file takes a tuple in the form
    (F, ret, finished_time, ol)
    and processes it to display/log the appropriate output.
    """
    global err, failed, time_dict
    F = result[0]
    ret = result[1]
    finished_time = result[2]
    ol = result[3]
    err = err | ret
    if ret != 0:
        if ret == 128:
            numfail = ol.count('Expected:') + ol.count('Expected nothing') + ol.count('Exception raised:')
            failed.append(test_cmd(F) + (" # %s doctests failed" % numfail))
            ret = numfail
        elif ret == 64:
            failed.append(test_cmd(F) + " # Time out")
        elif ret == 8:
            failed.append(test_cmd(F) + " # Exception from doctest framework")
        elif ret == 4:
            failed.append(test_cmd(F) + " # Killed/crashed")
        elif ret == 2:
            failed.append(test_cmd(F) + " # KeyboardInterrupt")
        elif ret == 1:
            failed.append(test_cmd(F) + " # File not found")
        else:
            failed.append(test_cmd(F))

    print test_cmd(F)
    sys.stdout.flush()

    if ol!="" and (not ol.isspace()):
        if (ol[len(ol)-1]=="\n"):
            ol=ol[0:len(ol)-1]
        print ol
        sys.stdout.flush()
    time_dict[abs_sage_path(F)] = finished_time
    if XML_RESULTS:
        t = finished_time
        failures = int(ret == 128)
        errors = int(ret and not failures)
        path = F.split(os.path.sep)
        while 'sage' in path:
            path = path[path.index('sage')+1:]
        path[-1] = os.path.splitext(path[-1])[0]
        module = '.'.join(path)
        if (failures or errors) and '#' in failed[-1]:
            type = "error" if errors else "failure"
            cause = failed[-1].split('#')[-1].strip()
            failure_item = "<%s type='%s'>%s</%s>" % (type, cause, cause, type)
        else:
            failure_item = ""
        f = open(os.path.join(XML_RESULTS, module + '.xml'), 'w')
        f.write("""
            <?xml version="1.0" ?>
            <testsuite name="%(module)s" errors="%(errors)s" failures="%(failures)s" tests="1" time="%(t)s">
            <testcase classname="%(module)s" name="test">
            %(failure_item)s
            </testcase>
            </testsuite>
        """.strip() % locals())
        f.close()
    print "\t [%.1f s]"%(finished_time)
    sys.stdout.flush()

def infiles_cmp(a,b):
    """
    Compare function used to sort the list of filenames by the time
    they took to run, as recorded in the global ``time_dict``.

    Files with a recorded time are ordered by that time.  A file without
    a recorded time sorts before any file that has one, and two files
    without a recorded time compare equal, so that the comparison is
    consistent in both directions.

    Returns a negative, zero or positive integer.
    """
    key_a = abs_sage_path(a)
    key_b = abs_sage_path(b)
    a_known = key_a in time_dict
    b_known = key_b in time_dict
    if a_known and b_known:
        time_a = time_dict[key_a]
        time_b = time_dict[key_b]
        return (time_a > time_b) - (time_a < time_b)
    if a_known:
        return 1
    if b_known:
        return -1
    return 0

def populatefilelist(filelist):
    """
    Populate the global list ``files`` from ``filelist``, expanding
    directories into the files they contain.

    Entries of ``filelist`` that are existing files are added (as
    absolute paths) unless skip() rejects them.  Every other entry is
    taken as a directory relative to the global ``CUR`` and walked
    recursively; files are added if they have a doctestable extension,
    are not rejected by skip() and do not live under ``BUILD_DIR``.
    Files of a directory containing a file named ``__nodoctest__`` are
    not added.  Subdirectories whose name contains ``#`` are not
    entered.

    The global lock ``filemutex`` is held while ``files`` is modified.
    Always returns 0.
    """
    doctest_exts = ('.sage', '.py', '.pyx', '.spyx', '.tex', '.pxi', '.rst')
    with filemutex:
        for FF in filelist:
            if os.path.isfile(FF):
                if not skip(FF):
                    if os.path.isabs(FF):
                        files.append(FF)
                    else:
                        files.append(os.path.join(os.getcwd(), FF))
                continue

            walkdir = os.path.join(CUR, FF)

            for root, dirs, lfiles in os.walk(walkdir):
                # Prune in place so that os.walk() does not descend into
                # the excluded directories.  A new list is built rather
                # than removing entries while iterating, which would
                # skip the entry following each removed one.
                # NOTE(review): 'dirs' holds bare names, so the 'notes'
                # test (which looks for a leading path separator) can
                # presumably never match -- confirm the intent.
                dirs[:] = [D for D in dirs
                           if not ('#' in D or (os.path.sep + 'notes' in D))]

                # A "__nodoctest__" marker excludes the files of this
                # directory.  The current practice is to put
                # "nodoctest.py" files in the directories that should be
                # skipped; that marker is handled in skip().
                if '__nodoctest__' in lfiles:
                    continue

                for F in lfiles:
                    if os.path.splitext(F)[1] not in doctest_exts:
                        continue
                    appendstr = os.path.join(root, F)
                    if skip(appendstr):
                        continue
                    if os.path.realpath(appendstr).startswith(BUILD_DIR):
                        continue
                    files.append(appendstr)
    return 0

for gr in range(0,numglobaliteration):
    argv = sys.argv
    opts = ' '.join([X for X in argv if X[0] == '-'])
    argv = [X for X in argv if X[0] != '-']

    try:
        numthreads = int(argv[1])
        infiles = argv[2:]
    except ValueError:
        # can't convert first arg to an integer: arg was probably omitted
        numthreads = 0
        infiles = argv[1:]

    if '-sagenb' in opts:
        opts = opts.replace('--sagenb', '').replace('-sagenb', '')

        # Find SageNB's home.
        from pkg_resources import Requirement, working_set
        sagenb_loc = working_set.find(Requirement.parse('sagenb')).location

        # In case we're using setuptools' "develop" mode.
        if not SAGE_SITE in sagenb_loc:
            opts += ' -force_lib'

        infiles.append(os.path.join(sagenb_loc, 'sagenb'))

    verbose = ('-verbose' in opts or '--verbose' in opts)

    if numthreads == 0:
        try:
            numthreads = int(os.environ['SAGE_NUM_THREADS_PARALLEL'])
        except KeyError:
            numthreads = 1

    if numthreads < 1 or len(infiles) == 0:
        if numthreads < 1:
            print "Usage: sage -tp <numthreads> <files or directories>: <numthreads> must be non-negative."
        else:
            print "Usage: sage -tp <numthreads> <files or directories>: no files or directories specified."
        print "For more information, type 'sage --advanced'."
        sys.exit(1)

    infiles.sort()

    files = list()

    t0 = time.time()
    filemutex = thread.allocate_lock()
    printmutex = thread.allocate_lock()
    SAGE_TESTDIR = os.environ['SAGE_TESTDIR']
    #Pick a filename for the timing files -- long vs normal
    # TODO: perhaps these files shouldn't be hidden?  Also, don't
    # store them in SAGE_TESTDIR, in case the user wants to test in
    # some temporary directory: store them somewhere more permanent.
    if opts.count("-long"):
        time_file_name = os.path.join(SAGE_TESTDIR,
                                      '.ptest_timing_long')
    else:
        time_file_name = os.path.join(SAGE_TESTDIR,
                                      '.ptest_timing')
    time_dict = { }
    try:
        with open(time_file_name) as time_file:
            time_dict = pickle.load(time_file)
        if opts.count("-long"):
            print "Using long cached timings to run longest doctests first."
        else:
            print "Using cached timings to run longest doctests first."
        from copy import copy
        time_dict_old = copy(time_dict)
    except:
        time_dict = { }
        if opts.count("-long"):
            print "No long cached timings exist; will create for successful files."
        else:
            print "No cached timings exist; will create for successful files."
        time_dict_old = None
    done = False

    CUR = abspath(os.getcwd())

    failed = []

    HOSTNAME = socket.gethostname().replace('-','_').replace('/','_').replace('\\','_')
    # Should TMP be a subdirectory of tempfile.gettempdir() rather than SAGE_TESTDIR?
    TMP = os.path.join(SAGE_TESTDIR, '%s-%s' % (HOSTNAME, os.getpid()))
    TMP = os.path.abspath(TMP)
    try:
        os.makedirs(TMP)
    except OSError:
        # If TMP already exists, remove it and re-create it
        if os.path.isdir(TMP):
            shutil.rmtree(TMP)
            os.makedirs(TMP)
        else:
            raise

    # Add rwx permissions for user to TMP:
    os.chmod(TMP, os.stat(TMP)[0] | stat.S_IRWXU)
    os.environ['SAGE_TESTDIR'] = TMP
    if verbose:
        print
        print "Using the directory"
        print "   '%s'." % TMP
        print "for doctesting.  If all doctests pass, this directory will"
        print "be deleted automatically."
        print

    populatefilelist(infiles)
    #Sort the files by test time
    files.sort(infiles_cmp)
    files.reverse()
    interrupt = False

    numthreads = min(numthreads, len(files))  # don't use more threads than files
        
    if len(files) == 1:
        file_str = "1 file"
    else:
        file_str = "%i files" % len(files)
    if numthreads == 1:
        jobs_str = "using 1 thread"  # not in parallel if numthreads is 1
    else:
        jobs_str = "doing %s jobs in parallel" % numthreads

    print "Doctesting %s %s" % (file_str, jobs_str)
        
    try:
        p = multiprocessing.Pool(numthreads)
        for r in p.imap_unordered(test_file, files):
            #The format is  (F, ret, finished_time, ol)
            process_result(r)
    except KeyboardInterrupt:
        err = err | 2
        interrupt = True
        pass
    print " "
    print "-"*int(70)

    os.chdir(CUR)

    if verbose:
        print
        print "Removing the directory '%s'." % TMP
    try:
        os.rmdir(TMP)
        if verbose:
            print
            print "-"*int(70)
    except OSError:
        # TODO (probably in sage-doctest): if tests were interrupted
        # but there were no failures in the interrupted files, delete
        # the temporary files, so that this directory is empty.
        print "The temporary doctesting directory"
        print "   %s" % TMP
        print "was not removed: it is not empty, presumably because doctests"
        print "failed or doctesting was interrupted."
        print
        print "-"*int(70)

    if len(failed) == 0:
        if interrupt == False:
            print "All tests passed!"
        else:
            print "Keyboard Interrupt: All tests that ran passed."
    else:
        if interrupt:
            print "Keyboard Interrupt, not all tests ran"
        elif opts=="-long" or len(opts)==0:
            time_dict_ran = time_dict
            time_dict = { }
            failed_files = { }
            if opts=="-long":
                for F in failed:
                    failed_files[F.split('#')[0].split()[3]] = None
            else:
                for F in failed:
                    failed_files[F.split('#')[0].split()[2]] = None
            for F in time_dict_ran:
                if F not in failed_files:
                    time_dict[F] = time_dict_ran[F]
            if time_dict_old is not None:
                for F in time_dict_old:
                    if F not in time_dict:
                        time_dict[F] = time_dict_old[F]
        print "\nThe following tests failed:\n"
        for i in range(len(failed)):
               print "\t", failed[i]
        print "-"*int(70)

    #Only update timings if we are doing something standard
    opts = opts.strip()
    if (opts=="-long" or len(opts)==0) and not interrupt:
        with open(time_file_name,"w") as time_file:
            pickle.dump(time_dict, time_file)
            print "Timings have been updated."

    print "Total time for all tests: %.1f seconds"%(time.time() - t0)

sys.exit(err)
