#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Usage: hawq upgrade [-h hostname] [-p port] [-U username]

import os
import sys
try:
    from gppylib.commands.unix import getLocalHostname, getUserName
    from gppylib.db import dbconn
    from gppylib.gplog import get_default_logger, setup_tool_logging
    from gppylib.gpparseopts import OptParser, OptChecker
    from pygresql import pg
    from hawqpylib.hawqlib import local_ssh, local_ssh_output
except ImportError, e:
    print e
    sys.stderr.write('Cannot import module, please check that you have source greenplum_path.sh\n')
    sys.exit(2)

# Initialise logging: this script's file name, the local host name and the
# current user name are handed to setup_tool_logging.
EXECNAME = os.path.basename(__file__)
logger = get_default_logger()
setup_tool_logging(EXECNAME, getLocalHostname(), getUserName())

def option_parser():
    '''Build the command-line parser for the upgrade script.

    Supported options are -h/--host, -p/--port (int, default 0) and
    -U/--user, all describing the target database.

    The script takes no positional arguments (the entry point prints the
    help text and exits when any are given), so the usage string lists
    options only.

    Returns the configured OptParser instance.
    '''
    parser = OptParser(option_class=OptChecker,
                       usage='usage: %prog [options]',
                       version='%prog version $Revision: #1 $')
    # Drop the existing '-h' option so that '-h' can be reused for --host.
    parser.remove_option('-h')
    parser.add_option('-h', '--host', help='host of the target DB')
    parser.add_option('-p', '--port', help='port of the target DB', type='int', default=0)
    parser.add_option('-U', '--user', help='username of the target DB')
    return parser

def register_func(options, args):
    '''(Re)create gp_relfile_insert_for_register in template1 and user databases.

    Connects to template1 with gp_session_role=utility, drops and recreates
    the function there, then does the same in every database listed in
    pg_database except template1, template0 and hcatalog.

    options -- parsed command-line options; host, port and user are read.
    args    -- positional arguments; not used by this function.

    On any pg.DatabaseError (including pg.InternalError) the failure is
    logged and the process exits with status 1.
    '''
    utility_conn = None
    try:
        # connect db
        url = dbconn.DbURL(hostname=options.host, port=options.port,
                           dbname="template1", username=options.user)
        utility_conn = pg.connect(dbname=url.pgdb, host=url.pghost, port=url.pgport,
                                  user=url.pguser, passwd=url.pgpass, opt='-c gp_session_role=utility')

        # register function for template1
        create_func_query = """
            CREATE FUNCTION gp_relfile_insert_for_register(Oid, Oid, Oid, Oid, Oid, cstring, char, char, Oid)
            RETURNS int4 LANGUAGE internal VOLATILE AS 'gp_relfile_insert_for_register'
            WITH (DESCRIPTION="insert record into gp_relfile_insert_for_register and gp_relfile_node");
            """
        drop_func_query = """
            DROP FUNCTION IF EXISTS gp_relfile_insert_for_register
            (Oid, Oid, Oid, Oid, Oid, cstring, char, char, Oid);
            """

        utility_conn.query(drop_func_query)
        utility_conn.query(create_func_query)
        logger.info('Function gp_relfile_insert_for_register successfully registered into database template1.')

        # register function for other databases
        query = """select datname from pg_database;"""
        dbs = utility_conn.query(query).dictresult()
        for cur_db in dbs:
            if cur_db['datname'] not in ['template1', 'template0', 'hcatalog']:
                url = dbconn.DbURL(hostname=options.host, port=options.port,
                                   dbname=cur_db['datname'], username=options.user)
                existed_db_conn = pg.connect(dbname=cur_db['datname'], host=url.pghost,
                                             port=url.pgport, user=url.pguser, passwd=url.pgpass)
                try:
                    existed_db_conn.query(drop_func_query)
                    existed_db_conn.query(create_func_query)
                    logger.info('Function gp_relfile_insert_for_register successfully registered into database %s.',cur_db['datname'])
                finally:
                    # Release the per-database connection even when a query fails.
                    existed_db_conn.close()
    except pg.InternalError as e:
        # Must come before pg.DatabaseError: InternalError is a subclass of
        # DatabaseError, so listing it second made this branch unreachable.
        logger.error('Fail to connect to database, this script can only be run when database is up.')
        logger.error('%s', e)
        logger.info('Hawq Upgrade Failed.')
        sys.exit(1)
    except pg.DatabaseError as e:
        logger.error('%s', e)
        logger.info('Hawq Upgrade Failed.')
        sys.exit(1)
    finally:
        # Runs on success and while sys.exit() is unwinding after a failure.
        if utility_conn is not None:
            utility_conn.close()

def pxf_get_item_fields_func(options, args):
    '''Update the pg_proc entry of pxf_get_item_fields in template1 and user databases.

    Connects to template1 with gp_session_role=utility and rewrites
    proallargtypes, proargmodes and proargnames for pxf_get_item_fields,
    then runs the same statement in every database listed in pg_database
    except template1, template0 and hcatalog.

    options -- parsed command-line options; host, port and user are read.
    args    -- positional arguments; not used by this function.

    On any pg.DatabaseError (including pg.InternalError) the failure is
    logged and the process exits with status 1.
    '''
    utility_conn = None
    try:
        # connect db
        url = dbconn.DbURL(hostname=options.host, port=options.port,
                           dbname="template1", username=options.user)
        utility_conn = pg.connect(dbname=url.pgdb, host=url.pghost, port=url.pgport,
                                  user=url.pguser, passwd=url.pgpass, opt='-c gp_session_role=utility')

        # update function for template1
        update_func_query = """
            SET allow_system_table_mods = 'dml';
            UPDATE pg_proc
            SET proallargtypes = '{25,25,25,25,25,25,25}',  proargmodes = '{i,i,o,o,o,o,o}',  proargnames = '{profile,pattern,path,itemname,fieldname,fieldtype,sourcefieldtype}'
            WHERE proname = 'pxf_get_item_fields';
            """

        utility_conn.query(update_func_query)
        logger.info('Function pxf_get_item_fields successfully updated into database template1.')

        # pxf_get_item_fields function for other databases
        query = """select datname from pg_database;"""
        dbs = utility_conn.query(query).dictresult()
        for cur_db in dbs:
            if cur_db['datname'] not in ['template1', 'template0', 'hcatalog']:
                url = dbconn.DbURL(hostname=options.host, port=options.port,
                                   dbname=cur_db['datname'], username=options.user)
                existed_db_conn = pg.connect(dbname=cur_db['datname'], host=url.pghost,
                                             port=url.pgport, user=url.pguser, passwd=url.pgpass)
                try:
                    existed_db_conn.query(update_func_query)
                    logger.info('Function pxf_get_item_fields successfully updated in database %s.',cur_db['datname'])
                finally:
                    # Release the per-database connection even when the query fails.
                    existed_db_conn.close()
    except pg.InternalError as e:
        # Must come before pg.DatabaseError: InternalError is a subclass of
        # DatabaseError, so listing it second made this branch unreachable.
        logger.error('Fail to connect to database, this script can only be run when database is up.')
        logger.error('%s', e)
        logger.info('Hawq Upgrade Failed.')
        sys.exit(1)
    except pg.DatabaseError as e:
        logger.error('%s', e)
        logger.info('Hawq Upgrade Failed.')
        sys.exit(1)
    finally:
        # Runs on success and while sys.exit() is unwinding after a failure.
        if utility_conn is not None:
            utility_conn.close()

if __name__ == '__main__':
    # Script entry point: parse the command line, refuse positional
    # arguments, then run both upgrade steps in order.
    parser = option_parser()
    options, args = parser.parse_args()
    if args:
        # Any positional argument is an error: show the help on stderr.
        parser.print_help(sys.stderr)
        sys.exit(1)

    register_func(options, args)
    pxf_get_item_fields_func(options, args)
    logger.info('Hawq Upgrade Succeed.')