#!/usr/bin/python
# Copyright 2011-2014 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
from pwd import getpwnam
import platform
import sys
import os
import stat
import errno
from functools import wraps
import threading
import re
import getopt
import resource
import signal
import logging
import logging.config
from vdsm.infra import sigutils
from vdsm.infra import zombiereaper

from caps import Architecture
import numaUtils

# Logger configuration file read at import time.
LOG_CONF_PATH = "/etc/vdsm/svdsm.logger.conf"

try:
    logging.config.fileConfig(LOG_CONF_PATH, disable_existing_loggers=False)
except Exception:
    # Best effort: if the configuration file is missing or broken, fall back
    # to DEBUG logging on stdout instead of failing to start. Only real
    # errors are handled here; SystemExit/KeyboardInterrupt propagate.
    logging.basicConfig(filename='/dev/stdout', filemode='w+',
                        level=logging.DEBUG)
    log = logging.getLogger("SuperVdsm.Server")
    log.warning("Could not init proper logging", exc_info=True)

from storage import fuser
from multiprocessing import Pipe, Process
try:
    from gluster import listPublicFunctions
    _glusterEnabled = True
except ImportError:
    _glusterEnabled = False

from vdsm import utils
from vdsm import sysctl
from vdsm.tool import restore_nets
from parted_utils import getDevicePartedInfo as _getDevicePartedInfo

from network import sourceroutethread
from network.api import editNetwork, setupNetworks, setSafeNetworkConfig,\
    change_numvfs

from network.tc import setPortMirroring, unsetPortMirroring
from storage.multipath import getScsiSerial as _getScsiSerial
from storage.iscsi import getDevIscsiInfo as _getdeviSCSIinfo
from storage.iscsi import readSessionInfo as _readSessionInfo
from supervdsm import _SuperVdsmManager
from storage import hba
from storage import multipath
from storage.fileUtils import chown, resolveGid, resolveUid
from storage.fileUtils import validateAccess as _validateAccess
from vdsm.constants import METADATA_GROUP, EXT_CHOWN, EXT_UDEVADM, \
    DISKIMAGE_USER, DISKIMAGE_GROUP, \
    P_LIBVIRT_VMCHANNELS, P_OVIRT_VMCONSOLES, \
    VDSM_USER, QEMU_PROCESS_USER, QEMU_PROCESS_GROUP, RHEV_ENABLED
from storage.devicemapper import _removeMapping, _getPathsStatus
from vdsm.config import config
import mkimage

# Directory and naming scheme of the udev rule files written by supervdsm.
_UDEV_RULE_FILE_DIR = "/etc/udev/rules.d/"
_UDEV_RULE_FILE_PREFIX = "99-vdsm-"
_UDEV_RULE_FILE_EXT = ".rules"

# Path template for per-device rules; filled with a (guid, thiefId) pair.
_UDEV_RULE_FILE_NAME = os.path.join(
    _UDEV_RULE_FILE_DIR,
    "".join((_UDEV_RULE_FILE_PREFIX, '%s-%s', _UDEV_RULE_FILE_EXT)))

# Path template for VFIO rules; filled with the iommu group.
_UDEV_RULE_FILE_NAME_VFIO = os.path.join(
    _UDEV_RULE_FILE_DIR,
    "".join((_UDEV_RULE_FILE_PREFIX, "iommu_group_%s", _UDEV_RULE_FILE_EXT)))

# Timeout handed to pipe.poll() while waiting for a _runAs child to answer.
RUN_AS_TIMEOUT = config.getint("irs", "process_pool_timeout")

# Main loop flag; the terminate() signal handler sets it to False.
_running = True


class Timeout(RuntimeError):
    """Raised when a child process does not answer within the time limit."""


def logDecorator(func):
    """
    Wrap func so that every call is traced on "SuperVdsm.ServerCallback".

    The arguments (without the first positional one, which is the instance
    when func is a method) are logged at debug level before the call and the
    result after it. If func raises, the error is logged with its traceback
    and the same exception is re-raised.
    """
    cbLog = logging.getLogger("SuperVdsm.ServerCallback")

    @wraps(func)
    def logged(*args, **kwargs):
        cbLog.debug('call %s with %s %s', func.__name__, args[1:], kwargs)
        try:
            result = func(*args, **kwargs)
        except:
            # Deliberately catch-all: whatever is raised is logged and then
            # propagated unchanged.
            cbLog.error("Error in %s", func.__name__, exc_info=True)
            raise
        else:
            cbLog.debug('return %s with %s', func.__name__, result)
            return result

    return logged


class _SuperVdsm(object):
    """
    Helper object served to clients by the supervdsm server.

    main() registers this class with the remote object manager under the
    name 'instance'. Each public method performs one operation inside the
    supervdsm process (main() exits unless that process runs as root).
    Public methods are wrapped with logDecorator, so every call, result and
    failure is logged on "SuperVdsm.ServerCallback".
    """

    # udevadm versions newer than this are driven with 'control --reload';
    # this version and older ones with 'control --reload-rules'.
    UDEV_WITH_RELOAD_VERSION = 181

    log = logging.getLogger("SuperVdsm.ServerCallback")

    @logDecorator
    def ping(self, *args, **kwargs):
        """Return True. This method exists for testing purposes."""
        return True

    @logDecorator
    def getHardwareInfo(self, *args, **kwargs):
        """
        Return the hardware information structure of this host.

        The backend is selected by platform.machine(); on architectures
        without a backend an empty dict is returned.
        """
        machine = platform.machine()
        if machine in ('x86_64', 'i686'):
            from dmidecodeUtil import getHardwareInfoStructure
        elif machine in Architecture.POWER:
            from ppc64HardwareInfo import getHardwareInfoStructure
        else:
            # not implemented for other architectures
            return {}
        return getHardwareInfoStructure()

    @logDecorator
    def getDevicePartedInfo(self, *args, **kwargs):
        """Forward the call to parted_utils.getDevicePartedInfo."""
        return _getDevicePartedInfo(*args, **kwargs)

    @logDecorator
    def getScsiSerial(self, *args, **kwargs):
        """Forward the call to storage.multipath.getScsiSerial."""
        return _getScsiSerial(*args, **kwargs)

    @logDecorator
    def resizeMap(self, devName):
        """Forward the call to storage.multipath._resize_map(devName)."""
        return multipath._resize_map(devName)

    @logDecorator
    def removeDeviceMapping(self, devName):
        """Forward the call to storage.devicemapper._removeMapping."""
        return _removeMapping(devName)

    @logDecorator
    def getdeviSCSIinfo(self, *args, **kwargs):
        """Forward the call to storage.iscsi.getDevIscsiInfo."""
        return _getdeviSCSIinfo(*args, **kwargs)

    @logDecorator
    def readSessionInfo(self, sessionID):
        """Forward the call to storage.iscsi.readSessionInfo(sessionID)."""
        return _readSessionInfo(sessionID)

    @logDecorator
    def getPathsStatus(self):
        """Forward the call to storage.devicemapper._getPathsStatus."""
        return _getPathsStatus()

    @logDecorator
    def getVmPid(self, vmName):
        """
        Return the raw content of the libvirt qemu pid file of vmName.

        :raises IOError: if the pid file cannot be opened
        """
        pidFile = "/var/run/libvirt/qemu/%s.pid" % vmName
        with open(pidFile) as pid:
            return pid.read()

    @logDecorator
    def getVcpuNumaMemoryMapping(self, vmName):
        """
        Map each vCPU index of vmName to the NUMA nodes found in its
        numa_maps file.

        :returns: dict of vCPU index -> list of distinct node numbers (in no
                  particular order). vCPUs whose numa_maps file cannot be
                  read are left out.
        """
        vmPid = self.getVmPid(vmName)
        vCpuPids = numaUtils.getVcpuPid(vmName)
        # numa_maps entries of the form "N<node>=<pages>"
        nodeRe = re.compile(r'N(\d+)=\d+')
        vCpuIdxToNode = {}
        for vCpuIndex, vCpuPid in vCpuPids.items():
            numaMapsFile = "/proc/%s/task/%s/numa_maps" % (vmPid, vCpuPid)
            try:
                with open(numaMapsFile, 'r') as f:
                    nodes = set(int(n) for n in nodeRe.findall(f.read()))
            except IOError:
                # The file is missing or unreadable: skip this vCPU.
                continue
            vCpuIdxToNode[vCpuIndex] = list(nodes)
        return vCpuIdxToNode

    @logDecorator
    def prepareVmChannel(self, socketFile, group=None):
        """
        Make a VM channel/console socket group-writable.

        :param socketFile: path of the socket; it must start with
                           P_LIBVIRT_VMCHANNELS or P_OVIRT_VMCONSOLES
        :param group: optional group to hand the socket to (the owning user
                      is kept)
        :raises Exception: if socketFile is outside the allowed locations
        """
        # NOTE(review): this is a plain prefix check on the given string;
        # '..' components and symlinks are not resolved - confirm callers
        # only pass trusted paths.
        allowed = (P_LIBVIRT_VMCHANNELS, P_OVIRT_VMCONSOLES)
        if not socketFile.startswith(allowed):
            raise Exception("Incorporate socketFile")

        fsinfo = os.stat(socketFile)
        os.chmod(socketFile, fsinfo.st_mode | stat.S_IWGRP)
        if group is not None:
            os.chown(socketFile, fsinfo.st_uid, resolveGid(group))

    @logDecorator
    def restoreNetworks(self):
        """Forward the call to vdsm.tool.restore_nets.restore()."""
        return restore_nets.restore()

    @logDecorator
    def editNetwork(self, oldBridge, newBridge, options):
        """Forward to network.api.editNetwork; options become keywords."""
        return editNetwork(oldBridge, newBridge, **options)

    @logDecorator
    def setupNetworks(self, networks, bondings, options):
        """Forward to network.api.setupNetworks; options become keywords."""
        return setupNetworks(networks, bondings, **options)

    @logDecorator
    def changeNumvfs(self, device_name, numvfs):
        """Forward the call to network.api.change_numvfs."""
        return change_numvfs(device_name, numvfs)

    def _runAs(self, user, groups, func, args=(), kwargs=None):
        """
        Run func(*args, **kwargs) in a child process as another user.

        The child switches to the groups resolved from *groups* (when given)
        and to the uid resolved from *user*, calls func and sends a
        (result, exception) pair back over a pipe.

        :param user: user to run as, resolved with resolveUid
        :param groups: groups to run with, resolved with resolveGid; the
                       first one becomes the gid. May be empty or None.
        :param func: callable executed in the child
        :param args: positional arguments for func
        :param kwargs: keyword arguments for func (None means no keywords)
        :returns: the value func returned in the child
        :raises Timeout: if the child does not answer within RUN_AS_TIMEOUT;
                         the child is sent SIGKILL
        :raises: whatever exception was caught in the child
        """
        if kwargs is None:
            kwargs = {}

        def child(pipe):
            res = ex = None
            try:
                uid = resolveUid(user)
                if groups:
                    gids = [resolveGid(group) for group in groups]

                    os.setgid(gids[0])
                    os.setgroups(gids)
                # Drop the uid last; the group changes above are done
                # while still running as the original user.
                os.setuid(uid)

                res = func(*args, **kwargs)
            except BaseException as e:
                ex = e

            pipe.send((res, ex))
            # Block until the parent acknowledges ("Bye") so the child does
            # not exit before its answer has been read.
            pipe.recv()

        pipe, hisPipe = Pipe()
        proc = Process(target=child, args=(hisPipe,))
        proc.start()

        needReaping = True
        try:
            if not pipe.poll(RUN_AS_TIMEOUT):
                try:
                    os.kill(proc.pid, signal.SIGKILL)
                except OSError as e:
                    # Don't add to zombiereaper if PID no longer exists
                    if e.errno == errno.ESRCH:
                        needReaping = False
                    else:
                        raise

                raise Timeout()

            res, err = pipe.recv()
            pipe.send("Bye")
            proc.terminate()

            if err is not None:
                raise err

            return res

        finally:
            # Add to zombiereaper if process has not been waited upon
            if proc.exitcode is None and needReaping:
                zombiereaper.autoReapPID(proc.pid)

    @logDecorator
    def validateAccess(self, user, groups, *args, **kwargs):
        """
        Run storage.fileUtils.validateAccess as the given user and groups.

        The remaining arguments are passed through to validateAccess; see
        _runAs for the timeout and error behaviour.
        """
        return self._runAs(user, groups, _validateAccess, args=args,
                           kwargs=kwargs)

    @logDecorator
    def setSafeNetworkConfig(self):
        """Forward the call to network.api.setSafeNetworkConfig()."""
        return setSafeNetworkConfig()

    @logDecorator
    def udevTrigger(self, guid):
        """
        Reload the udev rules and trigger a 'change' event for the devices
        whose DM_NAME property equals guid.

        :raises OSError: (EINVAL) if reloading the rules or the trigger
                         command fails
        """
        self.__udevReloadRules(guid)
        cmd = [EXT_UDEVADM, 'trigger', '--verbose', '--action', 'change',
               '--property-match=DM_NAME=%s' % guid]
        rc, out, err = utils.execCmd(cmd)
        if rc:
            # The message is built from adjacent literals: a backslash
            # continuation inside the string would embed the source
            # indentation in the text.
            raise OSError(errno.EINVAL,
                          "Could not trigger change for device %s, out %s\n"
                          "err %s" % (guid, out, err))

    @logDecorator
    def appropriateDevice(self, guid, thiefId):
        """
        Write the udev rule that chowns mapper/<guid> to the disk image
        user and group.

        :param guid: name of the device under mapper/
        :param thiefId: identifier embedded in the rule file name;
                        rmAppropriateRules(thiefId) removes the rule again
        """
        ruleFile = _UDEV_RULE_FILE_NAME % (guid, thiefId)
        # WARNING: we cannot use USER, GROUP and MODE since using any of them
        # will change the selinux label to the default, causing vms to pause.
        # See https://bugzilla.redhat.com/1147910
        rule = 'SYMLINK=="mapper/%s", RUN+="%s %s:%s $env{DEVNAME}"\n' % (
            guid, EXT_CHOWN, DISKIMAGE_USER, DISKIMAGE_GROUP)
        with open(ruleFile, "w") as rf:
            self.log.debug("Creating rule %s: %r", ruleFile, rule)
            rf.write(rule)

    @logDecorator
    def rmAppropriateRules(self, thiefId):
        """
        Remove every rule file written by appropriateDevice for thiefId.

        :returns: list of the rule file paths that could not be removed
        """
        # The fixed parts and thiefId are escaped so they only match
        # literally ('.' in the extension, any metacharacter in thiefId).
        apprDevRule = re.compile(
            "^" + re.escape(_UDEV_RULE_FILE_PREFIX) + ".*?-" +
            re.escape(thiefId) + re.escape(_UDEV_RULE_FILE_EXT) + "$")
        rules = [os.path.join(_UDEV_RULE_FILE_DIR, r) for r in
                 os.listdir(_UDEV_RULE_FILE_DIR)
                 if apprDevRule.match(r)]
        fails = []
        for r in rules:
            try:
                self.log.debug("Removing rule %s", r)
                os.remove(r)
            except OSError:
                fails.append(r)
        return fails

    @logDecorator
    def appropriateIommuGroup(self, iommu_group):
        """
        Create udev rule in /etc/udev/rules.d/ to change ownership
        of /dev/vfio/$iommu_group to qemu:qemu. This method should be called
        when detaching a device from the host.
        """
        rule_file = _UDEV_RULE_FILE_NAME_VFIO % iommu_group

        if os.path.isfile(rule_file):
            # If the file exists, different device from the same group has
            # already been detached and we therefore can skip overwriting
            # the file. Also, this file should only be created/removed via
            # the means of supervdsm.
            return

        # udev expects the key-value pairs of a rule to be comma separated;
        # the rule is newline terminated like the one written by
        # appropriateDevice.
        rule = ('KERNEL=="{}", SUBSYSTEM=="vfio", RUN+="{} {}:{} '
                '/dev/vfio/{}"\n').format(iommu_group, EXT_CHOWN,
                                          QEMU_PROCESS_USER,
                                          QEMU_PROCESS_GROUP,
                                          iommu_group)

        with open(rule_file, "w") as rf:
            self.log.debug("Creating rule %s: %r", rule_file, rule)
            rf.write(rule)

        self.udevTrigger(iommu_group)

    @logDecorator
    def rmAppropriateIommuGroup(self, iommu_group):
        """
        Remove udev rule in /etc/udev/rules.d/ created by
        appropriateIommuGroup.
        """
        # Same template as appropriateIommuGroup, so both methods always
        # agree on the path (and accept the same kinds of group value).
        rule_file = _UDEV_RULE_FILE_NAME_VFIO % iommu_group

        try:
            os.remove(rule_file)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
            # OSError with ENOENT errno here means that the rule file does
            # not exist - this is expected when multiple devices in one
            # iommu group were passed through. Nothing changed, so there is
            # nothing to trigger.
            return

        self.log.debug("Removing rule %s", rule_file)
        self.udevTrigger(iommu_group)

    @logDecorator
    def ksmTune(self, tuningParams):
        '''
        Set KSM tuning parameters for MOM, which runs without root privilege
        when it's launched by vdsm. So it needs supervdsm's assistance to
        tune KSM's parameters.

        :param tuningParams: dict of KSM parameter name -> value; each value
                             is written to /sys/kernel/mm/ksm/<name>
        :raises Exception: on an unknown parameter name or a value outside
                           the allowed range
        '''
        # Allowed parameters, each with its exclusive upper bound.
        KSM_PARAMS = {'run': 3, 'merge_across_nodes': 3,
                      'sleep_millisecs': 0x100000000,
                      'pages_to_scan': 0x100000000}
        for (k, v) in tuningParams.items():
            if k not in KSM_PARAMS:
                raise Exception('Invalid key in KSM parameter: %s=%s' % (k, v))
            if not 0 <= int(v) < KSM_PARAMS[k]:
                raise Exception('Invalid value in KSM parameter: %s=%s' %
                                (k, v))
            with open('/sys/kernel/mm/ksm/%s' % k, 'w') as f:
                f.write(str(v))

    @logDecorator
    def setPortMirroring(self, networkName, ifaceName):
        '''
        Copy networkName traffic of a bridge to an interface

        :param networkName: networkName bridge name to capture the traffic from
        :type networkName: string

        :param ifaceName: ifaceName to copy (mirror) the traffic to
        :type ifaceName: string

        this commands mirror all 'networkName' traffic to 'ifaceName'
        '''
        setPortMirroring(networkName, ifaceName)

    @logDecorator
    def unsetPortMirroring(self, networkName, target):
        '''
        Release captured mirror networkName traffic from networkName bridge

        :param networkName: networkName to release the traffic capture
        :type networkName: string
        :param target: target device to release
        :type target: string
        '''
        unsetPortMirroring(networkName, target)

    @logDecorator
    def mkFloppyFs(self, vmId, files, volId):
        """Forward the call to mkimage.mkFloppyFs."""
        return mkimage.mkFloppyFs(vmId, files, volId)

    @logDecorator
    def mkIsoFs(self, vmId, files, volId):
        """Forward the call to mkimage.mkIsoFs."""
        return mkimage.mkIsoFs(vmId, files, volId)

    @logDecorator
    def removeFs(self, path):
        """Forward the call to mkimage.removeFs(path)."""
        return mkimage.removeFs(path)

    @logDecorator
    def fuser(self, *args, **kwargs):
        """Forward the call to storage.fuser.fuser."""
        return fuser.fuser(*args, **kwargs)

    @logDecorator
    def hbaRescan(self):
        """Forward the call to storage.hba._rescan()."""
        return hba._rescan()

    @logDecorator
    def set_rp_filter_loose(self, dev):
        """Forward the call to vdsm.sysctl.set_rp_filter_loose(dev)."""
        sysctl.set_rp_filter_loose(dev)

    @logDecorator
    def set_rp_filter_strict(self, dev):
        """Forward the call to vdsm.sysctl.set_rp_filter_strict(dev)."""
        sysctl.set_rp_filter_strict(dev)

    def __udevReloadRules(self, guid):
        """
        Ask udev to reload its rules.

        :param guid: device the reload is done for; only used in the error
                     message
        :raises OSError: (EINVAL) if the udevadm command fails
        """
        if self.__udevOperationReload():
            reloadOpt = "--reload"
        else:
            reloadOpt = "--reload-rules"
        cmd = [EXT_UDEVADM, 'control', reloadOpt]
        rc, out, err = utils.execCmd(cmd)
        if rc:
            self.log.error("Udevadm reload-rules command failed rc=%s, "
                           "out=\"%s\", err=\"%s\"", rc, out, err)
            raise OSError(errno.EINVAL, "Could not reload-rules for device "
                          "%s" % guid)

    @utils.memoized
    def __udevVersion(self):
        """
        Return the udev version as an int, taken from the first element of
        the 'udevadm --version' output.

        :raises RuntimeError: if the udevadm command fails
        """
        cmd = [EXT_UDEVADM, '--version']
        rc, out, err = utils.execCmd(cmd)
        if rc:
            self.log.error("Udevadm version command failed rc=%s, "
                           " out=\"%s\", err=\"%s\"", rc, out, err)
            raise RuntimeError("Could not get udev version number")
        return int(out[0])

    def __udevOperationReload(self):
        """Return True if udev is newer than UDEV_WITH_RELOAD_VERSION."""
        return self.__udevVersion() > self.UDEV_WITH_RELOAD_VERSION


def terminate(signo, frame):
    """
    Signal handler: request shutdown by clearing the module's _running flag.

    Both arguments are the standard signal handler parameters and are not
    used.
    """
    global _running
    _running = False


def main(sockfile, pidfile=None):
    """
    Run the supervdsm server until a termination signal arrives.

    :param sockfile: path of the socket the remote object manager listens
                     on; a stale file at that path is removed first and the
                     socket is removed again on the way out
    :param pidfile: optional path; when given, the pid of this process is
                    written to it

    Exits with errno.EPERM if the effective uid is not 0, and with status 1
    (after logging the traceback) if any other Exception escapes.
    """
    log = logging.getLogger("SuperVdsm.Server")
    # Disable core dumps unless they are enabled in the configuration.
    if not config.getboolean('vars', 'core_dump_enable'):
        resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
    sigutils.register()
    zombiereaper.registerSignalHandler()

    def bind(func):
        # Adapt a plain function to a method: the first positional argument
        # (the instance) is dropped before calling func.
        def wrapper(_SuperVdsm, *args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    # When the gluster package could be imported, publish its public
    # functions as additional (logged) methods of _SuperVdsm.
    if _glusterEnabled:
        for name, func in listPublicFunctions(RHEV_ENABLED):
            setattr(_SuperVdsm, name, logDecorator(bind(func)))

    try:
        log.debug("Making sure I'm root - SuperVdsm")
        if os.geteuid() != 0:
            # SystemExit is not an Exception subclass, so this is not
            # caught by the handler at the bottom.
            sys.exit(errno.EPERM)

        if pidfile:
            pid = str(os.getpid())
            with open(pidfile, 'w') as f:
                f.write(pid + "\n")

        # NOTE(review): the arguments were already parsed by the caller;
        # this log text looks stale.
        log.debug("Parsing cmd args")
        address = sockfile

        log.debug("Cleaning old socket %s", address)
        if os.path.exists(address):
            os.unlink(address)

        # NOTE(review): no keep alive thread is created below; this log text
        # looks stale.
        log.debug("Setting up keep alive thread")

        try:
            # SIGTERM/SIGINT clear _running, which ends the loop below.
            signal.signal(signal.SIGTERM, terminate)
            signal.signal(signal.SIGINT, terminate)

            log.debug("Creating remote object manager")
            manager = _SuperVdsmManager(address=address, authkey='')
            manager.register('instance', callable=_SuperVdsm)

            # Serve requests from a daemon thread so it does not keep the
            # process alive once the main thread leaves the loop.
            server = manager.get_server()
            servThread = threading.Thread(target=server.serve_forever)
            servThread.setDaemon(True)
            servThread.start()

            # Hand the socket over to the vdsm user and METADATA_GROUP.
            chown(address, getpwnam(VDSM_USER).pw_uid, METADATA_GROUP)

            log.debug("Started serving super vdsm object")

            sourceroutethread.start()

            # Stay here until terminate() has cleared the flag.
            while _running:
                sigutils.wait_for_signal()

            log.debug("Terminated normally")
        finally:
            # Remove the socket whether we stopped normally or failed.
            if os.path.exists(address):
                utils.rmFile(address)

    except Exception:
        log.error("Could not start Super Vdsm", exc_info=True)
        sys.exit(1)


def _usage():
    """Print the command line synopsis to stdout."""
    # Parenthesised single-argument form: prints the same text under the
    # Python 2 print statement and is also valid as a Python 3 call.
    print("Usage:  supervdsmServer --sockfile=fullPath [--pidfile=fullPath]")


def _parse_args():
    """
    Parse sys.argv into the keyword arguments of main().

    :returns: dict with the key 'sockfile' and, when --pidfile was given,
              'pidfile'

    Prints the usage text and exits with status 1 when the command line
    cannot be parsed, when an option other than --sockfile/--pidfile is
    given (this includes -h), or when --sockfile is missing.
    """
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "h",
                                ["sockfile=", "pidfile="])
    except getopt.GetoptError:
        # Unknown option or missing option value: answer with the usage
        # text, like every other command line error, not with a traceback.
        _usage()
        sys.exit(1)

    argDict = {}
    for opt, value in opts:
        opt = opt.lower()
        if opt == "--sockfile":
            argDict['sockfile'] = value
        elif opt == "--pidfile":
            argDict['pidfile'] = value
        else:
            _usage()
            sys.exit(1)

    if 'sockfile' not in argDict:
        _usage()
        sys.exit(1)

    return argDict


if __name__ == '__main__':
    # Script entry point: the parsed options ('sockfile' and optionally
    # 'pidfile') are passed to main() as keyword arguments.
    argDict = _parse_args()
    main(**argDict)
