import tensorflow as tf
from config import args
from models import *
from utils import *
from data import Mnist
import tensorbayes as tb
from itertools import izip
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'

def main():
    log_file = make_file_name()
    print args

    def evaluate(x, y, xu, yu, eval_tensors, iw=1):
        if iw == 1:
            xs, ys, xus, yus = [x], [y], [xu], [yu]
        else:
            batches = 2000
            xs, ys = list(tb.nputils.split(x, batches)), list(tb.nputils.split(y, batches))
            xus, yus = list(tb.nputils.split(xu, batches)), list(tb.nputils.split(yu, batches))

        values = []
        for x, y, xu, yu in zip(xs, ys, xus, yus):
            feed_dict = {T.x: x, T.xu: xu, T.y: y, T.yu: yu, T.phase: 0, T.u: u, T.iw: iw}
            v = T.sess.run(eval_tensors, feed_dict)
            values += [v]

        values = [np.mean(v).astype(v[0].dtype) for v in zip(*values)]
        return values

    def train(T_train_step, T_loss, data, iterep, n_epochs):
        for i in xrange(iterep * n_epochs):
            x, y, xu, yu = data.next_batch(args.bs)
            feed_dict = {T.x: x, T.xu: xu, T.y: y, T.yu: yu, T.phase: 1, T.u: u, T.iw: 1}
            _, loss = T.sess.run([T_train_step, T_loss], feed_dict)

            message = "loss: {:.2e}".format(loss)
            end_epoch, epoch = tb.utils.progbar(i, iterep, message, bar_length=5)

            if np.isnan(loss):
                print "NaN detected"
                quit()

            if end_epoch:
                iw = 100 if epoch % args.n_checks == 0 else 1
                tr_values = evaluate(data.x_label, data.y_label, data.x_train,
                                     data.y_train, writer.tensors, iw=1)
                va_values = evaluate(data.x_valid, data.y_valid, data.x_valid,
                                     data.y_valid, writer.tensors[:-1], iw=iw)
                te_values = evaluate(data.x_test, data.y_test, data.x_test,
                                     data.y_test, writer.tensors[:-1], iw=iw)
                values = tr_values + va_values + te_values + [epoch]
                writer.write(values=values)

    def make_writer():
        # Make log file
        writer = tb.FileWriter(log_file, args=args, pipe_to_sys=True, overwrite=args.run >= 999)
        # Train log
        writer.add_var('train_iw', '{:4d}', T.iw)
        for v in ['bcde', 'bjde_x', 'bjde_xy', 'bjde_xu', 'bjde_yu', 'loss']:
            writer.add_var('train_{:s}'.format(v), '{:8.3f}', T[v])
        writer.add_var('l2_loss', '{:9.2e}', T.l2)

        # Validation log
        writer.add_var('valid_iw', '{:4d}')
        for v in ['bcde', 'bcde_x', 'bjde_xy', 'bjde_xu', 'bjde_yu', 'loss']:
            writer.add_var('valid_{:s}'.format(v), '{:8.3f}')

        # Test log
        writer.add_var('test_iw', '{:4d}')
        for v in ['bcde', 'bcde_x', 'bjde_xy', 'bjde_xu', 'bjde_yu', 'loss']:
            writer.add_var('test_{:s}'.format(v), '{:8.3f}')

        # Extra info
        writer.add_var('epoch', '{:>8d}')
        writer.initialize()
        return writer

    ###############
    # Build model #
    ###############
    tf.reset_default_graph()
    T = tb.utils.TensorDict(dict(
        bcde=constant(0),
        bjde_x=constant(0),
        bjde_xu=constant(0),
        bjde_yu=constant(0),
        bjde_xy=constant(0),
        l2=constant(0),
        loss=constant(0)))
    T.xu = placeholder((None, args.x_size), name='xu')
    T.yu = placeholder((None, args.y_size), name='yu')
    T.x = placeholder((None, args.x_size), name='x')
    T.y = placeholder((None, args.y_size), name='y')
    T.iw = placeholder(None, 'int32', name='iw') * 1 # hack for pholder eval
    T.u = placeholder(None, name='u')
    T.phase = placeholder(None, tf.bool, name='phase')

    if args.model == 'conditional':
        conditional(T)
    elif args.model in {'hybrid', 'hybrid_factored'}:
        hybrid(T)
    elif args.model == 'pretrained':
        pretrained(T)

    T.sess = tf.Session()
    T.sess.run(tf.global_variables_initializer())

    # Push all labeled data into unlabeled data set as well if using pretraining
    mnist = Mnist(args.n_label, args.seed, args.task, shift=args.shift,
                  duplicate='pretrain' in args.model, binarize=True)

    # Define remaining optimization hyperparameters
    if args.model == 'conditional':
        iterep = args.n_label / args.bs
        u = 1
    elif args.model in {'hybrid', 'hybrid_factored'}:
        iterep = args.n_total / args.bs
        u = 1 - args.n_label / float(args.n_total)
    elif args.model == 'pretrained':
        pretrain_iterep = args.n_total / args.bs
        iterep = args.n_label / args.bs
        u = 1

    # Sanity checks and creation of logger
    print "Data/Task statistics"
    print "Task:", args.task
    print "Data shapes of (x, y) for Labeled/Train/Valid/Test sets"
    print (mnist.x_label.shape, mnist.y_label.shape)
    print (mnist.x_train.shape, mnist.y_train.shape)
    print (mnist.x_valid.shape, mnist.y_valid.shape)
    print (mnist.x_test.shape, mnist.y_test.shape)
    writer = make_writer()

    ###############
    # Train model #
    ###############
    if 'pretrained' in args.model:
        print "Pretrain epochs, iterep", args.n_pretrain_epochs, pretrain_iterep
        train(T.pre_train_step, T.pre_loss, mnist, pretrain_iterep, args.n_pretrain_epochs)

    if 'hybrid' in args.model:
        print "Hybrid weighting on x_train and x_label:", (u, 1 - u)
    print "Epochs, Iterep", args.n_epochs, iterep
    train(T.train_step, T.loss, mnist, iterep, args.n_epochs)

if __name__ == '__main__':
    import warnings
    if '1.1.0' not in tf.__version__:
        warnings.warn("Library only tested in tf=1.1.0")
    main()