#!/usr/bin/env python
# Copyright 2016 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#

import argparse
from contextlib import contextmanager
import itertools
import sys
import threading

from ovirt_imageio_common import directio

from vdsm import concurrent
from vdsm import libvirtconnection
from vdsm import utils
from vdsm.password import ProtectedPassword

# Reference timestamp taken when this module is loaded; write_output()
# subtracts it from the current monotonic time to print elapsed seconds.
_start = utils.monotonic_time()


class StreamAdapter(object):
    """
    Expose a stream's recv() callable under the name read().

    download_volume() passes an instance of this class to directio.Receive
    as the source of the transfer.
    """

    def __init__(self, stream):
        # Look the callable up once, at construction time, so a stream
        # without recv() is rejected here rather than on the first read.
        receive = stream.recv
        self.read = receive


def arguments():
    """
    Parse the command line and return the resulting argparse namespace.

    The complete sys.argv list is parsed, program name included, so the
    positional 'args' argument exists only to consume sys.argv[0]; the
    program name ends up in the 'args' attribute of the result.

    --uri, --source and --dest are mandatory. --source and --dest each
    accept one or more values.
    """
    parser = argparse.ArgumentParser()
    # Placeholder that swallows sys.argv[0] (see parse_args() call below).
    parser.add_argument('args')
    parser.add_argument('--uri', dest='uri', required=True,
                        help='Libvirt URI')
    parser.add_argument('--username', dest='username', default='',
                        help='Libvirt login user name')
    parser.add_argument('--password-file', dest='password_file', default='',
                        help='Libvirt login password read from a file')
    parser.add_argument('--source', dest='source', nargs='+', required=True,
                        help='Source remote volumes path')
    parser.add_argument('--dest', dest='dest', nargs='+', required=True,
                        help='Destination local volumes path')
    # Let argparse interpolate the default so the help text cannot drift
    # away from the real value.
    parser.add_argument('--bufsize', dest='bufsize', default=1048576,
                        type=int, help='Size of packets in bytes, default '
                        '%(default)s')
    parser.add_argument('--verbose', action='store_true',
                        help='verbose output')
    return parser.parse_args(sys.argv)


def write_output(msg):
    """
    Print msg on stdout as one line prefixed with the elapsed time.

    The prefix is the number of seconds between the module-level _start
    timestamp and now. stdout is flushed after every message.
    """
    elapsed = utils.monotonic_time() - _start
    line = '[%7.1f] %s\n' % (elapsed, msg)
    out = sys.stdout
    out.write(line)
    out.flush()


def write_error(e):
    """Report e through write_output() with an "ERROR: " prefix."""
    message = "ERROR: %s" % e
    write_output(message)


def write_progress(progress):
    """
    Print a progress indicator such as "    (42/100%)" on stdout.

    The text ends with a carriage return instead of a newline, and stdout
    is flushed after each update.
    """
    indicator = '    (%d/100%%)\r' % progress
    out = sys.stdout
    out.write(indicator)
    out.flush()


def volume_progress(op, done, estimated_size):
    """
    Report the progress of op until it completes or done is set.

    While op.done is below estimated_size, a percentage capped at 99 is
    printed, then the loop waits up to one second on the done event.
    A final 100 is always printed on the way out.
    """
    stopped = False
    while not stopped and op.done < estimated_size:
        percent = op.done * 100 // estimated_size
        if percent > 99:
            # 100 is reserved for the final report below.
            percent = 99
        write_progress(percent)
        # wait() doubles as the polling delay and the stop signal.
        stopped = done.wait(1)
    write_progress(100)


@contextmanager
def progress(op, estimated_size):
    """
    Context manager running volume_progress() in a background thread.

    The thread is started on entry and yielded to the caller. On exit,
    whether or not the body raised, the thread is signalled to stop and
    joined.
    """
    finished = threading.Event()
    reporter = concurrent.thread(volume_progress,
                                 args=(op, finished, estimated_size))
    reporter.start()
    try:
        yield reporter
    finally:
        finished.set()
        reporter.join()


def download_volume(con, vol, dest, bufsize):
    """
    Download the storage volume vol into dest, reporting progress.

    Arguments:
        con: connection object used to create the transfer stream.
        vol: storage volume object to download.
        dest: destination handed to directio.Receive.
        bufsize: buffer size handed to directio.Receive.

    The stream is finished after a successful transfer. If the transfer
    fails, the stream is aborted instead and the error is propagated.
    """
    # Field 1 of vol.info() is the value main() logs as "capacity"; it is
    # used as the expected amount of data for the progress report.
    estimated_size = vol.info()[1]
    stream = con.newStream()
    # Offset 0, length 0, flags 0 -- presumably "the whole volume";
    # confirm against the libvirt virStorageVolDownload documentation.
    vol.download(stream, 0, 0, 0)
    completed = False
    try:
        sr = StreamAdapter(stream)
        op = directio.Receive(dest, sr, buffersize=bufsize)
        with progress(op, estimated_size):
            op.run()
        completed = True
    finally:
        # Do not leave a started stream dangling: finish it on success,
        # abort it when the transfer was interrupted.
        if completed:
            stream.finish()
        else:
            stream.abort()


def get_password(options):
    """
    Return the password stored in options.password_file, if any.

    Returns None when no password file was given. Otherwise the whole
    file content is read and returned wrapped in ProtectedPassword. In
    verbose mode the file name is logged before reading.
    """
    password_path = options.password_file
    if password_path:
        if options.verbose:
            write_output('>>> Reading password from file %s' %
                         password_path)
        with open(password_path, 'r') as password_file:
            content = password_file.read()
        return ProtectedPassword(content)
    return None


def main():
    """
    Copy every --source volume to the --dest path at the same position.

    The number of sources must match the number of destinations;
    otherwise an error is reported and the process exits with status 1
    before any connection is opened.
    """
    options = arguments()

    disk_count = len(options.source)
    dest_count = len(options.dest)
    if disk_count != dest_count:
        # Pairing the lists anyway would silently skip the extra entries
        # while still reporting a successful run.
        write_error('Number of source volumes (%d) does not match number '
                    'of destinations (%d)' % (disk_count, dest_count))
        sys.exit(1)

    con = libvirtconnection.open_connection(options.uri,
                                            options.username,
                                            get_password(options))

    bufsize = options.bufsize
    write_output('preparing for copy')
    disks = itertools.izip(options.source, options.dest)
    # enumerate() maintains the 1-based disk number; no manual increment.
    for diskno, (src, dst) in enumerate(disks, start=1):
        vol = con.storageVolLookupByPath(src)
        write_output('Copying disk %d/%d to %s' % (diskno, disk_count, dst))
        if options.verbose:
            # Query once so both numbers come from the same snapshot.
            info = vol.info()
            write_output('>>> disk %d, capacity: %d allocation %d' %
                         (diskno, info[1], info[2]))
        download_volume(con, vol, dst, bufsize)
    write_output('Finishing off')

# Run the copy only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
