# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
# vim: set filetype=python:
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

@imports('sys')
def die(*args):
    '''Print an error and terminate configure.

    All positional arguments are forwarded unchanged to `log.error`, after
    which the process exits with status 1 through `sys.exit`.'''
    log.error(*args)
    sys.exit(1)


@imports(_from='mozbuild.configure', _import='ConfigureError')
def configure_error(message):
    '''Raise a programming error and terminate configure.
    Primarily for use in moz.configure templates to sanity check
    their inputs from moz.configure usage.

    Unlike die(), nothing is logged here: the function only raises
    ConfigureError(message).'''
    raise ConfigureError(message)

# Run the command given as positional arguments (one per argv entry) and
# return what it wrote to stdout when it exits with status 0.
# On any other exit status, the captured stdout and stderr are sent to
# log.debug; then the `onerror` keyword callback, if one was given, is called
# and its result returned. Without a callback, die() is called.
@imports('subprocess')
@imports('sys')
@imports(_from='mozbuild.configure.util', _import='LineIO')
@imports(_from='mozbuild.shellutil', _import='quote')
def check_cmd_output(*args, **kwargs):
    onerror = kwargs.pop('onerror', None)

    with log.queue_debug():
        log.debug('Executing: `%s`', quote(*args))
        proc = subprocess.Popen(args, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        stdout, stderr = proc.communicate()
        # communicate() has already waited for the process to finish.
        retcode = proc.returncode
        if not retcode:
            return stdout

        log.debug('The command returned non-zero exit status %d.',
                  retcode)
        streams = [('output', stdout), ('error output', stderr)]
        for desc, data in streams:
            if not data:
                continue
            log.debug('Its %s was:', desc)
            # LineIO hands the data over to log.debug one line at a time.
            with LineIO(lambda l: log.debug('| %s', l)) as o:
                o.write(data)
        if onerror:
            return onerror()
        die('Command `%s` failed with exit status %d.' %
            (quote(*args), retcode))

@imports('os')
def is_absolute_or_relative(path):
    '''Return whether `path` contains a directory separator (os.sep, or
    os.altsep on platforms that define one), as opposed to being a bare
    file name.'''
    separators = [os.sep]
    if os.altsep:
        separators.append(os.altsep)
    return any(sep in path for sep in separators)


@imports(_import='mozpack.path', _as='mozpath')
def normsep(path):
    '''Return `path` as normalized by mozpack.path.normsep.'''
    # Thin wrapper so that other functions in this file can normalize
    # separators without importing mozpack.path themselves.
    return mozpath.normsep(path)


@imports('ctypes')
@imports(_from='ctypes', _import='wintypes')
@imports(_from='mozbuild.configure.constants', _import='WindowsBinaryType')
def windows_binary_type(path):
    """Obtain the type of a binary on Windows.

    Calls kernel32's GetBinaryTypeW on `path` and maps the reported type to
    a WindowsBinaryType constant ('win32' or 'win64').

    Terminates configure through die() when the call fails or when it
    reports any other binary type.
    """
    GetBinaryTypeW = ctypes.windll.kernel32.GetBinaryTypeW
    GetBinaryTypeW.argtypes = [wintypes.LPWSTR, wintypes.POINTER(wintypes.DWORD)]
    GetBinaryTypeW.restype = wintypes.BOOL

    # Output parameter filled in by GetBinaryTypeW.
    bin_type = wintypes.DWORD()
    res = GetBinaryTypeW(path, ctypes.byref(bin_type))
    if not res:
        die('could not obtain binary type of %s' % path)

    # Per the Win32 documentation, 0 is SCS_32BIT_BINARY and 6 is
    # SCS_64BIT_BINARY.
    if bin_type.value == 0:
        return WindowsBinaryType('win32')
    elif bin_type.value == 6:
        return WindowsBinaryType('win64')
    # If we see another binary type, something is likely horribly wrong.
    else:
        # Report the numeric value rather than the repr of the ctypes
        # object, which is what formatting `bin_type` itself would give.
        die('unsupported binary type on %s: %s' % (path, bin_type.value))


@imports('ctypes')
@imports(_from='ctypes', _import='wintypes')
def get_GetShortPathNameW():
    '''Return kernel32's GetShortPathNameW with its ctypes argument and
    return types declared. Only usable on Windows, since it goes through
    ctypes.windll.'''
    func = ctypes.windll.kernel32.GetShortPathNameW
    func.restype = wintypes.DWORD
    # (input long path, output buffer, size of the output buffer)
    func.argtypes = [
        wintypes.LPCWSTR,
        wintypes.LPWSTR,
        wintypes.DWORD,
    ]
    return func


# Template returning the normalize_path(path) implementation suited to the
# platform configure is running on. Both implementations normalize
# separators with normsep(); the Windows one additionally converts paths
# that would need quoting into their short (8.3) form.
@template
@imports('ctypes')
@imports('platform')
@imports(_from='mozbuild.shellutil', _import='quote')
def normalize_path():
    # Until the build system can properly handle programs that need quoting,
    # transform those paths into their short version on Windows (e.g.
    # c:\PROGRA~1...).
    if platform.system() == 'Windows':
        GetShortPathNameW = get_GetShortPathNameW()

        def normalize_path(path):
            path = normsep(path)
            # A path that quote() leaves untouched needs no quoting, so it
            # can be used as is.
            if quote(path) == path:
                return path
            # Start with an empty buffer and retry with the size reported by
            # GetShortPathNameW until the buffer is large enough.
            # NOTE(review): presumably a failed call reports 0, in which case
            # the first iteration returns an empty string - confirm against
            # the Win32 documentation.
            size = 0
            while True:
                out = ctypes.create_unicode_buffer(size)
                needed = GetShortPathNameW(path, out, size)
                if size >= needed:
                    # A space in the result means the path would still
                    # need quoting.
                    if ' ' in out.value:
                        die("GetShortPathName returned a long path name. "
                            "Are 8dot3 filenames disabled?")
                    return normsep(out.value)
                size = needed

    else:
        def normalize_path(path):
            return normsep(path)

    return normalize_path

# Replace the template with the implementation it selected, so that the rest
# of the file calls normalize_path(path) directly.
normalize_path = normalize_path()


# Look up the given program with which() and return its normalized path, or
# None when which() cannot find it.
# When `file` contains a directory separator, only the directory it names is
# searched. Otherwise the search uses the entries of `paths` when given
# (each entry may itself hold several directories joined with os.pathsep),
# and falls back to which()'s default search path when it is not.
@imports(_from='which', _import='which')
@imports(_from='which', _import='WhichError')
@imports('itertools')
@imports(_from='os', _import='pathsep')
def find_program(file, paths=None):
    try:
        if is_absolute_or_relative(file):
            directory = os.path.dirname(file)
            name = os.path.basename(file)
            return normalize_path(which(name, [directory]))
        if paths:
            if not isinstance(paths, (list, tuple)):
                die("Paths provided to find_program must be a list of strings, "
                    "not %r", paths)
            # Flatten the entries into one list of directories, skipping
            # empty entries.
            split_entries = (p.split(pathsep) for p in paths if p)
            paths = list(itertools.chain.from_iterable(split_entries))
        found = which(file, path=paths)
        return normalize_path(found)
    except WhichError:
        return None


@imports('os')
@imports('subprocess')
@imports(_from='mozbuild.configure.util', _import='LineIO')
@imports(_from='tempfile', _import='mkstemp')
def try_invoke_compiler(compiler, language, source, flags=None, onerror=None):
    '''Run `compiler` on a temporary file containing `source`.

    `compiler` is the compiler command as a list, `language` is either 'C'
    or 'C++' (any other value raises KeyError), and `flags` is an optional
    list or tuple of extra arguments placed before the source file path.
    `onerror` is handed to check_cmd_output, whose result is returned.
    The temporary file is removed before returning.'''
    flags = flags or []

    if not isinstance(flags, (list, tuple)):
        # Report the offending `flags` value.
        die("Flags provided to try_compile must be a list of strings, "
            "not %r", flags)

    suffix = {
        'C': '.c',
        'C++': '.cpp',
    }[language]

    fd, path = mkstemp(prefix='conftest.', suffix=suffix)
    try:
        # Close the descriptor even when encoding, logging or writing
        # fails, so that it does not leak and the file can be removed.
        try:
            source = source.encode('ascii', 'replace')

            log.debug('Creating `%s` with content:', path)
            with LineIO(lambda l: log.debug('| %s', l)) as out:
                out.write(source)

            os.write(fd, source)
        finally:
            os.close(fd)
        cmd = compiler + list(flags) + [path]
        return check_cmd_output(*cmd, onerror=onerror)
    finally:
        os.remove(path)


def unique_list(l):
    '''Return a new list with the items of `l` in their original order,
    keeping only the first occurrence of each item.

    Items are compared with `==` through list membership, so they do not
    need to be hashable.'''
    result = []
    for item in l:
        # Test the item itself, not the whole input list.
        if item not in result:
            result.append(item)
    return result


# Get values out of the Windows registry. This function can only be called on
# Windows.
# The `pattern` argument is a string starting with HKEY_ and giving the full
# "path" of the registry key to get the value for, with backslash separators.
# The string can contain wildcards ('*').
# This function is a generator yielding one tuple for each match. Each of
# these tuples contains the names that matched a path component containing a
# wildcard, followed by the value.
#
# Examples:
#   get_registry_values(r'HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\'
#                       r'Windows Kits\Installed Roots\KitsRoot*')
#   yields e.g.:
#     ('KitsRoot81', r'C:\Program Files (x86)\Windows Kits\8.1\')
#     ('KitsRoot10', r'C:\Program Files (x86)\Windows Kits\10\')
#
#   get_registry_values(r'HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\'
#                       r'Windows Kits\Installed Roots\KitsRoot8.1')
#   yields e.g.:
#     (r'C:\Program Files (x86)\Windows Kits\8.1\',)
#
#   get_registry_values(r'HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\'
#                       r'Windows Kits\*\KitsRoot*')
#   yields e.g.:
#     ('Installed Roots', 'KitsRoot81',
#      r'C:\Program Files (x86)\Windows Kits\8.1\')
#     ('Installed Roots', 'KitsRoot10',
#      r'C:\Program Files (x86)\Windows Kits\10\')
#
#   get_registry_values(r'HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\'
#                       r'VisualStudio\VC\*\x86\*\Compiler')
#   yields e.g.:
#     ('19.0', 'arm', r'C:\...\amd64_arm\cl.exe')
#     ('19.0', 'x64', r'C:\...\amd64\cl.exe')
#     ('19.0', 'x86', r'C:\...\amd64_x86\cl.exe')
@imports(_import='_winreg', _as='winreg')
@imports(_from='__builtin__', _import='WindowsError')
@imports(_from='fnmatch', _import='fnmatch')
def get_registry_values(pattern):
    # Yield func(key, 0), func(key, 1), ... until func raises WindowsError.
    def enum_helper(func, key):
        i = 0
        while True:
            try:
                yield func(key, i)
            except WindowsError:
                break
            i += 1

    # `pattern` is a list of path components. Open the key reached from `key`
    # through all but the last component, then yield (name, opened key) for
    # each of its subkeys whose name matches the last component. Keys that
    # cannot be opened are skipped.
    def get_keys(key, pattern):
        try:
            s = winreg.OpenKey(key, '\\'.join(pattern[:-1]))
        except WindowsError:
            return
        for k in enum_helper(winreg.EnumKey, s):
            if fnmatch(k, pattern[-1]):
                try:
                    yield k, winreg.OpenKey(s, k)
                except WindowsError:
                    pass

    # Same lookup as get_keys, but enumerate the values of the opened key and
    # yield (name, data) for each value whose name matches the last
    # component.
    def get_values(key, pattern):
        try:
            s = winreg.OpenKey(key, '\\'.join(pattern[:-1]))
        except WindowsError:
            return
        for k, v, t in enum_helper(winreg.EnumValue, s):
            if fnmatch(k, pattern[-1]):
                yield k, v

    # Group the path components into chunks, each ending with a component
    # that contains a wildcard. Components left after the last wildcard form
    # a final chunk.
    def split_pattern(pattern):
        subpattern = []
        for p in pattern:
            subpattern.append(p)
            if '*' in p:
                yield subpattern
                subpattern = []
        if subpattern:
            yield subpattern

    pattern = pattern.split('\\')
    assert pattern[0].startswith('HKEY_')
    # Each entry of `keys` is a tuple holding the names matched so far,
    # followed by the key to continue the search from.
    keys = [(getattr(winreg, pattern[0]),)]
    pattern = list(split_pattern(pattern[1:]))
    for i, p in enumerate(pattern):
        next_keys = []
        for base_key in keys:
            matches = base_key[:-1]
            base_key = base_key[-1]
            if i == len(pattern) - 1:
                # The last chunk names values rather than keys. The value
                # name is only part of the result when it was matched
                # through a wildcard.
                want_name = '*' in p[-1]
                for name, value in get_values(base_key, p):
                    yield matches + ((name, value) if want_name else (value,))
            else:
                for name, k in get_keys(base_key, p):
                    next_keys.append(matches + (name, k))
        keys = next_keys


# The class is imported under the `_Version` alias because this function
# reuses the `Version` name.
@imports(_from='mozbuild.configure.util', _import='Version', _as='_Version')
def Version(v):
    'A version number that can be compared usefully.'
    return _Version(v)

# Declare a deprecated option. This combines option() and @depends:
#   @deprecated_option('--option')
#   def option(value):
#       ...
# @deprecated_option() accepts the same arguments as option(), with the
# exception of `help`, which is always set to 'Deprecated'.
# The decorated function is only called when the option value does not come
# from its default. It may handle the option like a typical @depends
# function would, but it is recommended that it emits a deprecation error
# message suggesting an alternative option to use if there is one.
@template
def deprecated_option(*args, **kwargs):
    assert 'help' not in kwargs
    opt = option(*args, help='Deprecated', **kwargs)

    def decorator(func):
        @depends(opt.option)
        def deprecated(value):
            if value.origin == 'default':
                return None
            return func(value)
        return deprecated

    return decorator


# Equivalent of `from mozbuild.util import ReadOnlyNamespace as namespace`,
# written as a function that builds a ReadOnlyNamespace from its keyword
# arguments.
@imports(_from='mozbuild.util', _import='ReadOnlyNamespace')
def namespace(**kwargs):
    return ReadOnlyNamespace(**kwargs)


# Make an object usable as an argument to @depends.
# A @depends function is returned unchanged. A plain function taking no
# argument is wrapped with @depends(when=True). Any other object is treated
# as a literal value and wrapped in a @depends(when=True) function that
# returns it.
@template
@imports(_from='inspect', _import='isfunction')
@imports(_from='mozbuild.configure', _import='SandboxDependsFunction')
def dependable(obj):
    if isinstance(obj, SandboxDependsFunction):
        return obj
    producer = obj if isfunction(obj) else (lambda: obj)
    return depends(when=True)(producer)


# @depends functions wrapping the literal values True and False, for use
# where a @depends function is expected rather than an immediate value.
always = dependable(True)
never = dependable(False)


# Some @depends functions return namespaces, and one could want to use one
# specific attribute from such a namespace as a "value" given to functions
# such as `set_config`. But those functions do not take immediate values.
# The `delayed_getattr` function allows access to attributes from the result
# of a @depends function in a non-immediate manner.
#   @depends('--option')
#   def option(value):
#       return namespace(foo=value)
#   set_config('FOO', delayed_getattr(option, 'foo'))
@template
def delayed_getattr(func, key):
    # `func` is the @depends function to read from, and `key` the name of
    # the attribute to get from its result.
    @depends(func)
    def result(value):
        # The @depends function we're being passed may have returned
        # None, or an object that simply doesn't have the wanted key.
        # In that case, just return None.
        return getattr(value, key, None)

    return result


# Variant of @depends where the decorated function is only called when at
# least one of the values it would receive is truthy (bool(value) is True).
# When none is, the resulting @depends function returns None.
@template
def depends_if(*args):
    def decorator(func):
        @depends(*args)
        def wrapper(*values):
            if not any(values):
                return None
            return func(*values)
        return wrapper
    return decorator

# Like @depends_if, but a distinguished value passed as a keyword argument
# "when" is truth tested instead of every argument. This value is not passed
# to the function if it is called.
# "when" is the only keyword argument accepted, and it is required. When its
# value is itself falsy (e.g. None), this is equivalent to a plain @depends.
@template
def depends_when(*args, **kwargs):
    # Reject both a missing "when" and any additional keyword argument.
    if len(kwargs) != 1 or 'when' not in kwargs:
        die('depends_when requires a single keyword argument, "when"')
    when = kwargs['when']
    if not when:
        return depends(*args)

    def decorator(fn):
        # `when` is resolved as the first dependency, and stripped from the
        # arguments handed to the decorated function.
        @depends(when, *args)
        def wrapper(val, *args):
            if val:
                return fn(*args)
        return wrapper
    return decorator


# Hacks related to old-configure
# ==============================

# List that the functions created by add_old_configure_assignment() append
# their 'VAR=value' strings to.
@dependable
def old_configure_assignments():
    return []

# List that the functions created by add_old_configure_arg() append their
# arguments to.
@dependable
def extra_old_configure_args():
    return []

# Append a `VAR=value` assignment to the old_configure_assignments list.
# Both `var` and `value` may be literals or @depends functions. Nothing is
# added when either resolves to None. True is written as `VAR=1`, False as
# `VAR=`, and anything else is shell-quoted (lists and tuples are first
# turned into a single quoted string).
@template
def add_old_configure_assignment(var, value):
    var = dependable(var)
    value = dependable(value)

    @depends(old_configure_assignments, var, value)
    @imports(_from='mozbuild.shellutil', _import='quote')
    def add_assignment(assignments, var, value):
        if var is None or value is None:
            return
        if value is True:
            assignment = '%s=1' % var
        elif value is False:
            assignment = '%s=' % var
        else:
            if isinstance(value, (list, tuple)):
                value = quote(*value)
            assignment = '%s=%s' % (var, quote(str(value)))
        assignments.append(assignment)

# Append the value of the given @depends function (or option) to the
# extra_old_configure_args list, unless that value is falsy.
@template
def add_old_configure_arg(arg):
    @depends(extra_old_configure_args, arg)
    def add_arg(args, arg):
        if not arg:
            return
        args.append(arg)