#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Usage1: hawq register [-h hostname] [-p port] [-U username] [-d database] [-f filepath] [-e eof] <tablename>
# Usage2: hawq register [-h hostname] [-p port] [-U username] [-d database] [-c config] [--force] <tablename>

import os
import sys
try:
    from gppylib.commands.unix import getLocalHostname, getUserName
    from gppylib.db import dbconn
    from gppylib.gplog import get_default_logger, setup_tool_logging
    from gppylib.gpparseopts import OptParser, OptChecker
    from pygresql import pg
    from hawqpylib.hawqlib import local_ssh, local_ssh_output
except ImportError, e:
    print e
    sys.stderr.write('Cannot import module, please check that you have source greenplum_path.sh\n')
    sys.exit(2)

# Module-wide logger, this script's own file name, and an audit line that
# records the full command line the process was started with.
logger = get_default_logger()
EXECNAME = os.path.split(__file__)[-1]

cmd = "Executing Command: " + "".join(arg + " " for arg in sys.argv)
logger.info(cmd)

def option_parser():
    '''Build the command-line parser for hawq register.

    The parser's built-in -h is removed so that -h can select the target
    host; help is available through -?/--help instead.
    '''
    parser = OptParser(option_class=OptChecker,
                       usage='usage: %prog [options] table_name',
                       version='%prog version $Revision: #1 $')
    parser.remove_option('-h')
    option_specs = [
        (('-?', '--help'), dict(action='help')),
        (('-h', '--host'), dict(help='host of the target DB')),
        (('-p', '--port'), dict(help='port of the target DB', type='int', default=0)),
        (('-U', '--user'), dict(help='username of the target DB')),
        (('-d', '--database'), dict(default='postgres', dest='database', help='database name')),
        (('-f', '--filepath'), dict(dest='filepath', help='file name in HDFS')),
        (('-e', '--eof'), dict(dest='filesize', type='int', default=None, help='eof of the file to be registered')),
        (('-c', '--config'), dict(dest='yml_config', default='', help='configuration file in YAML format')),
        (('-F', '--force'), dict(dest='force', action='store_true', default=False)),
        (('-l', '--logdir'), dict(dest='logDir', help="Sets the directory for log files")),
    ]
    # Registration order matters for the generated --help text.
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser


def _exit_missing_yaml_attr(attr):
    '''Log that yaml attribute ``attr`` is missing and exit with status 1.'''
    logger.error('Wrong configuration yaml file format: "%s" attribute does not exist.\n See example in "hawq register --help".' % attr)
    sys.exit(1)


def register_yaml_dict_check(D, table_column_num, src_tablename):
    '''Validate the parsed yaml configuration; exit with status 1 on the first problem.

    Checks, in order:
      * DFS_URL, Distribution_Policy, FileFormat and TableName are present;
      * for "DISTRIBUTED BY" policies, Bucketnum is present and > 0;
      * FileFormat is 'Parquet' or 'AO';
      * <Format>_FileLocations and <Format>_Schema are present, every file has
        a path and a size, and every schema column has a name and a type;
      * the format-specific storage options exist in <Format>_FileLocations;
      * when table_column_num > 0, the yaml schema has that many columns.

    :param D: dict loaded from the yaml configuration file
    :param table_column_num: column count of the existing table; the query
        behind it counts pg_attribute rows, so it is 0 when no table matches
    :param src_tablename: table name used in the column-count error message
    '''
    for attr in ['DFS_URL', 'Distribution_Policy', 'FileFormat', 'TableName']:
        if D.get(attr) is None:
            _exit_missing_yaml_attr(attr)
    if D['Distribution_Policy'].startswith('DISTRIBUTED BY'):
        if D.get('Bucketnum') is None:
            # Name the attribute that is actually missing (the stale loop
            # variable from the check above used to be reported here).
            _exit_missing_yaml_attr('Bucketnum')
        if D['Bucketnum'] <= 0:
            logger.error('Bucketnum should not be zero, please check your yaml configuration file.')
            sys.exit(1)

    fmt = D['FileFormat']
    if fmt not in ['Parquet', 'AO']:
        logger.error('hawq register only support Parquet and AO formats. Format %s is not supported.' % fmt)
        sys.exit(1)
    locations_key = '%s_FileLocations' % fmt
    schema_key = '%s_Schema' % fmt
    for attr in [locations_key, schema_key]:
        if D.get(attr) is None:
            _exit_missing_yaml_attr(attr)

    locations = D[locations_key]
    if locations.get('Files') is None:
        _exit_missing_yaml_attr('%s.Files' % locations_key)
    for entry in locations['Files']:
        for key in ['path', 'size']:
            if entry.get(key) is None:
                _exit_missing_yaml_attr('%s.Files.%s' % (locations_key, key))
    for column in D[schema_key]:
        for key in ['name', 'type']:
            if column.get(key) is None:
                _exit_missing_yaml_attr('%s.%s' % (schema_key, key))

    # Storage options consumed when the table is created; their values may be
    # null, so only key presence is required.
    if fmt == 'Parquet':
        sub_check_list = ['CompressionLevel', 'CompressionType', 'PageSize', 'RowGroupSize']
    else:
        sub_check_list = ['Checksum', 'CompressionLevel', 'CompressionType']
    for attr in sub_check_list:
        if attr not in locations:
            _exit_missing_yaml_attr('%s.%s' % (locations_key, attr))

    yml_column_num = len(D[schema_key])
    if table_column_num != yml_column_num and table_column_num > 0:
        logger.error('Column number of table in yaml file is not equals to the column number of table %s.' % src_tablename)
        sys.exit(1)

def ispartition(yml_file):
    '''Return True if the yaml configuration describes a partitioned table.

    The table counts as partitioned when the ``<Format>_FileLocations`` section
    (``Parquet_FileLocations`` for Parquet, ``AO_FileLocations`` otherwise) has
    a truthy ``PartitionBy`` entry. Exits with status 1 if the file is not
    valid yaml or its top level is not a mapping.
    '''
    import yaml
    try:
        with open(yml_file, 'r') as f:
            # NOTE(review): yaml.load can construct arbitrary Python objects;
            # yaml.safe_load is the safer choice if the file is not trusted.
            params = yaml.load(f)
    except yaml.YAMLError as e:
        # YAMLError is the base of ScannerError, ParserError, etc.
        print(e)
        sys.exit(1)
    if not isinstance(params, dict):
        logger.error('Invalid yaml file %s: expected a mapping at the top level.' % yml_file)
        sys.exit(1)

    if str(params.get('FileFormat', '')).lower() == 'parquet':
        Format = 'Parquet'
    else:  # AO format
        Format = 'AO'
    file_locations = params.get('%s_FileLocations' % Format)
    return bool(file_locations and file_locations.get('PartitionBy'))

def tablename_handler(tablename):
    '''Split a dotted table name into (schema, table).

    Only the last two dot-separated components are used; a name without a
    dot is placed in the 'public' schema.
    '''
    parts = tablename.split('.')
    if len(parts) <= 1:
        return 'public', tablename
    schema, table = parts[-2:]
    return schema, table

def check_file_exist(yml_file):
    '''Report whether a yaml configuration path was supplied.

    Returns False when ``yml_file`` is empty/None, True when it names an
    existing path, and logs an error then exits with status 1 otherwise.
    '''
    if not yml_file:
        return False
    if os.path.exists(yml_file):
        return True
    logger.error('Cannot find yaml file : %s' % yml_file)
    sys.exit(1)

class FailureHandler(object):
    '''Records side-effecting registration steps so they can be undone.

    Each recorded operation is a ``(kind, command)`` tuple. ``'SQL'`` commands
    are ``create table name(...)`` statements and are undone with DROP TABLE.
    ``'HDFSCMD'`` commands are shell commands, undone by running them again
    with their last two arguments swapped.
    '''

    def __init__(self, conn):
        self.conn = conn
        self.operations = []

    def commit(self, cmd):
        '''Remember ``cmd`` (a ``(kind, command)`` tuple) for a later rollback.'''
        self.operations.append(cmd)

    def assemble_SQL(self, cmd):
        '''Return the DROP TABLE statement undoing a ``create table name(...)`` statement.'''
        name_start = cmd.find('table') + len('table ')
        name_end = cmd.find('(')
        return 'DROP TABLE %s' % cmd[name_start:name_end]

    def assemble_hdfscmd(self, cmd):
        '''Return ``cmd`` with its last two whitespace-separated arguments swapped.'''
        tokens = cmd.strip().split()
        head, second_last, last = tokens[:-2], tokens[-2], tokens[-1]
        return ' '.join(head + [last, second_last])

    def _undo_sql(self, statement):
        '''Drop the table created by ``statement``; exit(1) on a database error.'''
        drop_sql = self.assemble_SQL(statement)
        try:
            self.conn.query(drop_sql)
        except pg.DatabaseError as e:
            logger.error('Rollback failure: %s.' % drop_sql)
            print(e)
            sys.exit(1)

    def _undo_hdfscmd(self, command):
        '''Run the reversed form of ``command``; exit(1) if local_ssh reports failure.'''
        reverse_cmd = self.assemble_hdfscmd(command)
        sys.stdout.write('Rollback hdfscmd: "%s"\n' % reverse_cmd)
        if local_ssh(reverse_cmd, logger) != 0:
            logger.error('Fail to rollback: %s.' % reverse_cmd)
            sys.exit(1)

    def rollback(self):
        '''Undo every recorded operation, newest first.'''
        has_work = len(self.operations) != 0
        if has_work:
            logger.info('Error found, Hawqregister starts to rollback...')
        for kind, command in reversed(self.operations):
            if kind == 'SQL':
                self._undo_sql(command)
            elif kind == 'HDFSCMD':
                self._undo_hdfscmd(command)
        if has_work:
            logger.info('Hawq Register Rollback Finished.')


class GpRegisterAccessor(object):
    '''Catalog queries used by hawq register.

    All SQL goes through ``conn.query(sql)``; ``exec_query`` returns the rows
    as a list of dicts via ``dictresult()``.

    NOTE(review): schema/table names are interpolated into the SQL text with
    ``%`` formatting, so names containing single quotes break the queries.
    '''

    def __init__(self, conn):
        '''Bind to ``conn`` and cache oid, name, tablespace, encoding and version of the current database.'''
        self.conn = conn
        rows = self.exec_query("""
        SELECT oid, datname, dat2tablespace,
               pg_encoding_to_char(encoding) encoding
        FROM pg_database WHERE datname=current_database()""")
        self.dbid = rows[0]['oid']
        self.dbname = rows[0]['datname']
        self.spcid = rows[0]['dat2tablespace']
        self.dbencoding = rows[0]['encoding']
        self.dbversion = self.exec_query('select version()')[0]['version']

    def exec_query(self, sql):
        '''execute query and return dict result'''
        return self.conn.query(sql).dictresult()

    def get_schema_id(self, schemaname):
        '''Return the pg_namespace oid of ``schemaname`` (IndexError if there is none).'''
        qry = """select oid from pg_namespace where nspname='%s';""" % schemaname
        return self.exec_query(qry)[0]['oid']

    def get_table_existed(self, tablename):
        '''Return True if exactly one pg_class row matches the (schema-qualified) ``tablename``.'''
        schemaname, tablename = tablename_handler(tablename)
        qry = """select count(*) from pg_class where relnamespace = '%s' and relname = '%s';""" % (self.get_schema_id(schemaname), tablename)
        return self.exec_query(qry)[0]['count'] == 1

    def get_table_column_num(self, tablename):
        '''Return the number of user columns (attnum > 0) of ``tablename``; 0 if no table matches.'''
        schemaname, tablename = tablename_handler(tablename)
        qry = """select count(*) from pg_attribute ,pg_class where pg_class.relnamespace = '%s' and pg_class.relname = '%s' and pg_class.oid = pg_attribute.attrelid and attnum > 0;""" % (self.get_schema_id(schemaname), tablename)
        return self.exec_query(qry)[0]['count']

    def do_create_table(self, src_table_name, tablename, schema_info, fmt, distrbution_policy, file_locations, bucket_number, partitionby, partitions_constraint, partitions_name):
        '''Create ``tablename`` from yaml metadata unless it already exists.

        :param src_table_name: table name recorded in the yaml; its last
            component followed by '_1_prt_' is stripped from partition names
        :param tablename: table to create
        :param schema_info: list of {'name': ..., 'type': ...} column dicts
        :param fmt: 'AO' (created with orientation=ROW) or 'Parquet'
        :param distrbution_policy: policy clause appended verbatim, e.g. 'DISTRIBUTED RANDOMLY'
        :param file_locations: dict supplying the compression/storage options of the WITH clause
        :param bucket_number: bucketnum, emitted only when the policy is not 'DISTRIBUTED RANDOMLY'
        :param partitionby: 'PARTITION BY ...' clause, or None for an unpartitioned table
        :param partitions_constraint: per-partition constraint text
        :param partitions_name: partition table names, parallel to partitions_constraint
        :returns: (False, '') if the table already exists, else (True, the executed CREATE statement)
        '''
        if self.get_table_existed(tablename):
            return False, ''
        schema = ','.join([k['name'] + ' ' + k['type'] for k in schema_info])
        partlist = ""
        for index, constraint in enumerate(partitions_constraint):
            if index > 0:
                partlist += ", "
            # Partition table names look like "<table>_1_prt_<name>"; keep only <name>.
            splitter = src_table_name.split(".")[-1] + '_1_prt_'
            partition_refine_name = partitions_name[index].split(splitter)[-1]
            # In some cases the constraint already contains "PARTITION XXX", in others it
            # does not; only prepend the partition name when it is missing.
            if constraint.strip().startswith("DEFAULT PARTITION") or constraint.strip().startswith("PARTITION") or (len(partition_refine_name) > 0 and partition_refine_name[0].isdigit()):
                partlist = partlist + " " + constraint
            else:
                partlist = partlist + "PARTITION " + partition_refine_name + " " + constraint

        # The closing parenthesis of the WITH (...) clause is supplied here.
        bucket_number_policy = ', bucketnum=%s)' % bucket_number if distrbution_policy != 'DISTRIBUTED RANDOMLY' else ')'
        fmt = 'ROW' if fmt == 'AO' else fmt
        if fmt == 'ROW':
            if partitionby is None:
                query = ('create table %s(%s) with (appendonly=true, orientation=%s, compresstype=%s, compresslevel=%s, checksum=%s%s %s;'
                         % (tablename, schema, fmt, file_locations['CompressionType'], file_locations['CompressionLevel'], file_locations['Checksum'], bucket_number_policy, distrbution_policy))
            else:
                query = ('create table %s(%s) with (appendonly=true, orientation=%s, compresstype=%s, compresslevel=%s, checksum=%s%s %s %s (%s);'
                         % (tablename, schema, fmt, file_locations['CompressionType'], file_locations['CompressionLevel'], file_locations['Checksum'], bucket_number_policy, distrbution_policy, partitionby, partlist))
        else:  # Parquet
            if partitionby is None:
                query = ('create table %s(%s) with (appendonly=true, orientation=%s, compresstype=%s, compresslevel=%s, pagesize=%s, rowgroupsize=%s%s %s;'
                         % (tablename, schema, fmt, file_locations['CompressionType'], file_locations['CompressionLevel'], file_locations['PageSize'], file_locations['RowGroupSize'], bucket_number_policy, distrbution_policy))
            else:
                query = ('create table %s(%s) with (appendonly=true, orientation=%s, compresstype=%s, compresslevel=%s, pagesize=%s, rowgroupsize=%s%s %s %s (%s);'
                         % (tablename, schema, fmt, file_locations['CompressionType'], file_locations['CompressionLevel'], file_locations['PageSize'], file_locations['RowGroupSize'], bucket_number_policy, distrbution_policy, partitionby, partlist))
        self.conn.query(query)
        return True, query

    def is_hash_distributed(self, tablename):
        '''Return True if gp_distribution_policy.attrnums of ``tablename`` is non-empty.'''
        schemaname, tablename = tablename_handler(tablename)
        qry = """select attrnums from gp_distribution_policy, pg_class where pg_class.relnamespace = '%s' and pg_class.relname = '%s' and pg_class.oid = gp_distribution_policy.localoid;""" % (self.get_schema_id(schemaname), tablename)
        rows = self.exec_query(qry)
        if rows[0]['attrnums']:
            return True
        return False

    def check_hash_type(self, tablename):
        '''Exit with status 1 if ``tablename`` has no policy row or is hash distributed.'''
        schemaname, tablename = tablename_handler(tablename)
        qry = """select attrnums from gp_distribution_policy, pg_class where pg_class.relnamespace = '%s' and pg_class.relname = '%s' and pg_class.oid = gp_distribution_policy.localoid;""" % (self.get_schema_id(schemaname), tablename)
        rows = self.exec_query(qry)
        if len(rows) == 0:
            logger.error('Table %s is not an append-only table. There is no record in gp_distribution_policy table.' % tablename)
            sys.exit(1)
        if rows[0]['attrnums']:
            logger.error('Cannot register file(s) to a table which is hash distributed.')
            sys.exit(1)

    # pg_paqseg_#
    def get_seg_name(self, tablename, database, fmt):
        '''Return (segment relation name, True) for ``tablename``, or ('', False) after logging an error.

        For fmt == 'Parquet' the segment relation name must contain 'paq'.
        '''
        schemaname, tablename = tablename_handler(tablename)
        query = ("select pg_class2.relname from pg_class as pg_class1, pg_appendonly, pg_class as pg_class2 "
                 "where pg_class1.relname ='%s' and pg_class1.oid = pg_appendonly.relid and pg_appendonly.segrelid = pg_class2.oid and pg_class1.relnamespace = '%s';") % (tablename, self.get_schema_id(schemaname))
        rows = self.exec_query(query)
        if len(rows) == 0:
            logger.error('table "%s" not found in db "%s"' % (tablename, database))
            return ('', False)
        relname = rows[0]['relname']
        if fmt == 'Parquet':
            if relname.find('paq') == -1:
                logger.error("table '%s' is not parquet format" % tablename)
                return ('', False)
        return (relname, True)

    def _get_distribution_policy_row(self, tablename):
        '''Return the gp_distribution_policy row of ``tablename`` (looked up via its pg_class oid).'''
        schemaname, tablename = tablename_handler(tablename)
        query = "select oid from pg_class where relnamespace = '%s' and relname = '%s';" % (self.get_schema_id(schemaname), tablename)
        oid = self.exec_query(query)[0]['oid']
        query = "select * from gp_distribution_policy where localoid = '%s';" % oid
        return self.exec_query(query)[0]

    def get_distribution_policy_info(self, tablename):
        '''Return gp_distribution_policy.attrnums of ``tablename`` as returned by the driver.'''
        return self._get_distribution_policy_row(tablename)['attrnums']

    def get_partition_info(self, tablename):
        ''' Get partition information from pg_partitions, return a constraint-tablename dictionary '''
        schemaname, tablename = tablename_handler(tablename)
        query = "SELECT partitiontablename, partitionboundary FROM pg_partitions WHERE partitionschemaname = '%s' and tablename = '%s'" % (schemaname, tablename)
        return self.exec_query(query)

    def get_partition_parent(self, tablename):
        '''Return (partitiontablename, parentpartitiontablename) rows from pg_partitions.'''
        schemaname, tablename = tablename_handler(tablename)
        query = "SELECT partitiontablename, parentpartitiontablename FROM pg_partitions WHERE partitionschemaname = '%s' and tablename = '%s'" % (schemaname, tablename)
        return self.exec_query(query)

    def get_partitionby(self, tablename):
        '''Return "PARTITION BY <type> (<column>)" built from the first partition row and first key column.'''
        schemaname, tablename = tablename_handler(tablename)
        query = "SELECT partitionschemaname, partitiontablename, partitionname, partitiontype, parentpartitiontablename, partitionboundary FROM pg_partitions WHERE schemaname = '%s' and tablename='%s';" % (schemaname, tablename)
        partition_type = self.exec_query(query)[0]['partitiontype']
        query = "SELECT columnname, partitionlevel FROM pg_partition_columns WHERE schemaname = '%s' and tablename='%s' ORDER BY position_in_partition_key;" % (schemaname, tablename)
        partition_columnname = self.exec_query(query)[0]['columnname']
        partitionby = 'PARTITION BY %s (%s)' % (partition_type, partition_columnname)
        return partitionby

    def get_partition_num(self, tablename):
        '''Return the number of pg_partitions rows of ``tablename``.'''
        schemaname, tablename = tablename_handler(tablename)
        query = "SELECT partitionschemaname from pg_partitions WHERE schemaname = '%s' and tablename='%s';" % (schemaname, tablename)
        return len(self.exec_query(query))

    def get_bucket_number(self, tablename):
        '''Return gp_distribution_policy.bucketnum of ``tablename``.'''
        return self._get_distribution_policy_row(tablename)['bucketnum']

    def get_metadata_from_database(self, tablename, seg_name):
        '''Return (first free segno, table directory) for ``tablename``.

        The first segno is one more than the row count of pg_aoseg.<seg_name>.
        The directory is "<location>/<tablespace oid>/<database oid>/<relfilenode>/".
        '''
        schemaname, tablename = tablename_handler(tablename)
        query = 'select segno from pg_aoseg.%s;' % seg_name
        firstsegno = len(self.exec_query(query)) + 1
        # Get the full path of the corresponding directory for the target table.
        # The filespace must be joined to the table's tablespace; comparing
        # filespace_oid with itself cross-joined every filespace and made the
        # chosen location arbitrary when more than one filespace exists.
        query = ("select location, gp_persistent_tablespace_node.tablespace_oid, database_oid, relfilenode from pg_class, gp_persistent_relation_node, "
                 "gp_persistent_tablespace_node, gp_persistent_filespace_node where relnamespace = '%s' and relname = '%s' and pg_class.relfilenode = "
                 "gp_persistent_relation_node.relfilenode_oid and gp_persistent_relation_node.tablespace_oid = gp_persistent_tablespace_node.tablespace_oid "
                 "and gp_persistent_tablespace_node.filespace_oid = gp_persistent_filespace_node.filespace_oid;") % (self.get_schema_id(schemaname), tablename)
        D = self.exec_query(query)[0]
        tabledir = '/'.join([D['location'].strip(), str(D['tablespace_oid']), str(D['database_oid']), str(D['relfilenode']), ''])
        return firstsegno, tabledir

    def get_metadata_from_seg_name(self, seg_name):
        '''Return ([segno as str, ...], [eof as int, ...]) from pg_aoseg.<seg_name>.'''
        query = 'select segno, eof from pg_aoseg.%s;' % seg_name
        rows = self.exec_query(query)
        return [str(row['segno']) for row in rows], [int(row['eof']) for row in rows]

    def get_database_encoding_indx(self, database):
        '''Return pg_database.encoding (numeric id) of ``database``.'''
        query = "select encoding from pg_database where datname = '%s';" % database
        return self.exec_query(query)[0]['encoding']

    def get_database_encoding(self, encoding_indx):
        '''Return the encoding name for a numeric encoding id.'''
        query = "select pg_encoding_to_char(%s);" % encoding_indx
        return self.exec_query(query)[0]['pg_encoding_to_char']

    def get_metadata_for_relfile_insert(self, database, tablename):
        '''Return catalog facts about ``tablename`` in ``database``.

        :returns: [tablespace_oid, database_oid, oid, relfilenode, relname,
            relkind, relstorage, relam]; a reltablespace of 0 is replaced by
            the oid of the 'dfs_default' tablespace
        '''
        schemaname, tablename = tablename_handler(tablename)
        schema_id = self.get_schema_id(schemaname)
        query = "select reltablespace from pg_class where relnamespace = '%s' and relname = '%s';" % (schema_id, tablename)
        tablespace_oid = int(self.exec_query(query)[0]['reltablespace'])
        if tablespace_oid == 0:
            query = "select oid from pg_tablespace where spcname='dfs_default';"
            tablespace_oid = int(self.exec_query(query)[0]['oid'])
        query = "select oid from pg_database where datname='%s';" % database
        database_oid = int(self.exec_query(query)[0]['oid'])
        query = "select oid, relfilenode, relname, relkind, relstorage, relam from pg_class where relnamespace = '%s' and relname='%s';" % (schema_id, tablename)
        # Run the query once rather than once per returned column.
        row = self.exec_query(query)[0]
        return [tablespace_oid, database_oid, int(row['oid']), int(row['relfilenode']), str(row['relname']),
                str(row['relkind']), str(row['relstorage']), int(row['relam'])]

    def update_catalog(self, query):
        '''Execute a catalog-modifying statement.'''
        self.conn.query(query)


class HawqRegister(object):
    def __init__(self, options, table, utility_conn, conn, failure_handler):
        '''Capture options, create catalog accessors and determine the mode.

        :param options: parsed options (yml_config, filepath, database, filesize, force are read)
        :param table: destination table name; stored lower-cased
        :param utility_conn: connection wrapped by utility_accessor (used for the segment-name lookup)
        :param conn: connection wrapped by accessor (all other catalog access)
        :param failure_handler: FailureHandler recording undoable operations
        '''
        self.yml = options.yml_config
        self.filepath = options.filepath
        self.database = options.database
        self.dst_table_name = table.lower()
        self.tablename = table.lower()
        self.filesize = options.filesize
        self.accessor = GpRegisterAccessor(conn)
        self.utility_accessor = GpRegisterAccessor(utility_conn)
        self.failure_handler = failure_handler
        # May query the catalog and exit when no yaml config is given.
        self.mode = self._init_mode(options.force)
        self.srcfiles = []
        self.dstfiles = []
        # NOTE(review): the *_same_path lists start empty and are filled by
        # other methods of this class.
        self.files_same_path = []
        self.sizes_same_path = []
        self.segnos_same_path = []
        self.tupcounts_same_path = []
        self.varblockcounts_same_path = []
        self.eofuncompresseds_same_path = []

    def _init_mode(self, force):
        def table_existed():
            return self.accessor.get_table_existed(self.dst_table_name)

        if self.yml:
            if force:
                return 'force'
            else:
                return 'usage2_table_not_exist'
        else:
            if not table_existed():
                logger.error('Table %s does not exist.\nYou should create table before registering the data.' % self.dst_table_name)
                sys.exit(1)
            else:
                return 'usage1'

    def _is_hash_distributed(self):
        '''Return True when the destination table has distribution columns.'''
        return self.accessor.is_hash_distributed(tablename=self.dst_table_name)

    def _check_hash_type(self):
        '''Exit via the accessor if the destination table cannot accept files (hash distributed or no policy row).'''
        self.accessor.check_hash_type(tablename=self.dst_table_name)

    def _create_table(self):
        '''Create the destination table from the yaml metadata.

        Returns True if the table was created, in which case the CREATE
        statement is recorded with the failure handler so a rollback can
        drop it. Returns False if the table already existed. On a database
        error the error is printed, previously recorded operations are rolled
        back (as the other failure paths do) and the process exits with status 1.
        '''
        try:
            (ret, query) = self.accessor.do_create_table(self.src_table_name, self.dst_table_name, self.schema, self.file_format,
                                                         self.distribution_policy, self.file_locations, self.bucket_number,
                                                         self.partitionby, self.partitions_constraint, self.partitions_name)
        except pg.DatabaseError as e:
            print(e)
            self.failure_handler.rollback()
            sys.exit(1)
        if ret:
            self.failure_handler.commit(('SQL', query))
        return ret

    def _check_database_encoding(self):
        '''Ensure the yaml Encoding equals the encoding of the -d database.

        On mismatch the error is logged, recorded operations are rolled back
        (consistent with the other validation checks) and the process exits
        with status 1.
        '''
        encoding_indx = self.accessor.get_database_encoding_indx(self.database)
        encoding = self.accessor.get_database_encoding(encoding_indx)
        if self.encoding.strip() != encoding:
            logger.error('Database encoding from yaml configuration file(%s) is not consistent with encoding from input args(%s).' % (self.encoding, encoding))
            self.failure_handler.rollback()
            sys.exit(1)

    def _check_policy_consistency(self):
        policy = self._get_distribution_policy() # "" or "{1,3}"
        if policy is None:
            return
        if self.distribution_policy == 'DISTRIBUTED RANDOMLY':
            logger.error('Distribution policy of %s from yaml is not consistent with the policy of existing table.' % self.tablename)
            self.failure_handler.rollback()
            sys.exit(1)
        tmp_dict = {}
        for i, d in enumerate(self.schema):
            tmp_dict[d['name']] = i + 1
        # 'DISTRIBUETD BY (1,3)' -> {1,3}
        cols = self.distribution_policy.strip().split()[-1].strip('(').strip(')').split(',')
        original_policy = ','.join([str(tmp_dict[col]) for col in cols])
        if policy.strip('{').strip('}') != original_policy:
            logger.error('Distribution policy of %s from yaml file is not consistent with the policy of existing table.' % self.dst_table_name)
            self.failure_handler.rollback()
            sys.exit(1)

    def _get_metadata_for_relfile_insert(self):
        '''Return the accessor's relfile metadata list for the target table in self.database.'''
        return self.accessor.get_metadata_for_relfile_insert(database=self.database, tablename=self.tablename)

    def _set_yml_data(self, file_format, files, sizes, tupcounts, eofuncompresseds, varblockcounts, tablename, schema, distribution_policy, file_locations,\
                      bucket_number, partitionby, partitions_constraint, partitions_name, partitions_compression_level,\
                      partitions_compression_type, partitions_checksum, partitions_filepaths, partitions_filesizes, \
                      partitions_tupcounts, partitions_eofuncompresseds, partitions_varblockcounts, encoding):
        self.file_format = file_format
        self.files = files
        self.sizes = sizes
        self.tupcounts = tupcounts
        self.eofuncompresseds = eofuncompresseds
        self.varblockcounts = varblockcounts
        self.src_table_name = tablename
        self.schema = schema
        self.distribution_policy = distribution_policy
        self.file_locations = file_locations
        self.bucket_number = bucket_number
        self.partitionby = partitionby
        self.partitions_constraint = partitions_constraint
        self.partitions_name = partitions_name
        self.partitions_compression_level = partitions_compression_level
        self.partitions_compression_type = partitions_compression_type
        self.partitions_checksum = partitions_checksum
        self.partitions_filepaths = partitions_filepaths
        self.partitions_filesizes = partitions_filesizes
        self.partitions_tupcounts = partitions_tupcounts
        self.partitions_eofuncompresseds = partitions_eofuncompresseds
        self.partitions_varblockcounts = partitions_varblockcounts
        self.encoding = encoding

    def _option_parser_yml(self, yml_file):
        import yaml
        try:
            with open(yml_file, 'r') as f:
                params = yaml.load(f)
        except yaml.scanner.ScannerError as e:
            print e
            logger.error('In _option_parser_yml(): %s' % e)
            self.failure_handler.rollback()
            sys.exit(1)
            
        table_column_num = self.accessor.get_table_column_num(self.tablename)
        register_yaml_dict_check(params, table_column_num, self.tablename)
        partitions_filepaths = []
        partitions_filesizes = []
        partitions_constraint = []
        partitions_name = []
        partitions_checksum = []
        partitions_compression_level = []
        partitions_compression_type = []
        partitions_tupcounts = []
        partitions_eofuncompresseds = []
        partitions_varblockcounts = []
        files, sizes, tupcounts, eofuncompresseds, varblockcounts = [], [], [], [], []

        try:
            if params['FileFormat'].lower() == 'parquet':
                Format = 'Parquet'
            else: #AO format
                Format = 'AO'
            Format_FileLocations = '%s_FileLocations' % Format
            partitionby = params.get(Format_FileLocations).get('PartitionBy')
            if params.get(Format_FileLocations).get('Partitions') and len(params[Format_FileLocations]['Partitions']):
                partitions_checksum = [d['Checksum'] for d in params[Format_FileLocations]['Partitions']]
                partitions_compression_level = [d['CompressionLevel'] for d in params[Format_FileLocations]['Partitions']]
                partitions_compression_type = [d['CompressionType'] for d in params[Format_FileLocations]['Partitions']]
                partitions_constraint = [d['Constraint'] for d in params[Format_FileLocations]['Partitions']]
                partitions_files = [d['Files'] for d in params[Format_FileLocations]['Partitions']]
                if len(partitions_files):
                    for pfile in partitions_files:
                        partitions_filepaths.append([params['DFS_URL'] + item['path'] for item in pfile])
                        partitions_filesizes.append([item['size'] for item in pfile])
                        partitions_tupcounts.append([item['tupcount'] if item.has_key('tupcount') else -1 for item in pfile])
                        partitions_eofuncompresseds.append([item['eofuncompressed'] if item.has_key('eofuncompressed') else -1 for item in pfile])
                        partitions_varblockcounts.append([item['varblockcount'] if item.has_key('varblockcount') else -1 for item in pfile])
                partitions_name = [d['Name'] for d in params[Format_FileLocations]['Partitions']]
            if len(params[Format_FileLocations]['Files']):
                for ele in params[Format_FileLocations]['Files']:
                    files.append(params['DFS_URL'] + ele['path'])
                    sizes.append(ele['size'])
                    tupcounts.append(ele['tupcount'] if ele.has_key('tupcount') else -1)
                    eofuncompresseds.append(ele['eofuncompressed'] if ele.has_key('eofuncompressed') else -1)
                    varblockcounts.append(ele['varblockcount'] if ele.has_key('varblockcount') else -1)

            encoding = params['Encoding']
            bucketNum = params['Bucketnum'] if params['Distribution_Policy'].startswith('DISTRIBUTED BY') else 6
            self._set_yml_data(Format, files, sizes, tupcounts, eofuncompresseds, varblockcounts, params['TableName'], params['%s_Schema' % Format], params['Distribution_Policy'], params[Format_FileLocations], bucketNum, partitionby, partitions_constraint, partitions_name, partitions_compression_level, partitions_compression_type, partitions_checksum, partitions_filepaths, partitions_filesizes, partitions_tupcounts, partitions_eofuncompresseds, partitions_varblockcounts, encoding)

        except KeyError as e:
            logger.error('Invalid yaml file, %s is missing.' % e)
            self.failure_handler.rollback()
            sys.exit(1)


    # check conflicting distributed policy
    def _check_distribution_policy(self):
        if self.distribution_policy.startswith('DISTRIBUTED BY'):
            if len(self.files) % self.bucket_number != 0:
                logger.error('Files to be registered must be multiple times to the bucket number of hash table.')
                self.failure_handler.rollback()
                sys.exit(1)

    def _get_seg_name(self):
        '''Return (segment relation name, found flag) via the utility-connection accessor.'''
        return self.utility_accessor.get_seg_name(tablename=self.tablename, database=self.database, fmt=self.file_format)

    def _get_metadata(self):
        '''Return (first segno, table directory) for the target table and self.seg_name.'''
        return self.accessor.get_metadata_from_database(tablename=self.tablename, seg_name=self.seg_name)

    def _get_metadata_from_table(self):
        '''Return (segnos, eofs) read from pg_aoseg.<self.seg_name>.'''
        return self.accessor.get_metadata_from_seg_name(seg_name=self.seg_name)

    def _get_distribution_policy(self):
        '''Return gp_distribution_policy.attrnums of the target table.'''
        return self.accessor.get_distribution_policy_info(tablename=self.tablename)

    def _check_bucket_number(self):
        def get_bucket_number():
            return self.accessor.get_bucket_number(self.tablename)

        if self.bucket_number != get_bucket_number():
            logger.error('Bucket number of %s is not consistent with previous bucket number.' % self.tablename)
            self.failure_handler.rollback()
            sys.exit(1)

    def _check_file_not_folder(self, pn=''):
        '''Verify that every path in self.files is a regular file in HDFS.

        pn: optional partition name; only used to label the progress log lines.
        On the first path that fails `hadoop fs -test -f`, the problem is logged
        as an error, the failure handler rolls back, and the process exits
        with status 1 (consistent with the other checks in this class).
        '''
        suffix = ' for table %s' % pn if pn else ''
        logger.info('Files check%s...' % suffix)
        for fn in self.files:
            # `hadoop fs -test -f` exits non-zero unless the path is an existing regular file.
            if local_ssh('hadoop fs -test -f %s' % fn, logger):
                logger.error('%s is not a file in hdfs, please check the yaml configuration file.' % fn)
                self.failure_handler.rollback()
                sys.exit(1)
        logger.info('Files check done%s.' % suffix)

    def _is_folder(self, filepath):
        '''Return True when `hadoop fs -test -d filepath` succeeds (exit status 0).'''
        status = local_ssh('hadoop fs -test -d %s' % filepath, logger)
        return not status

    def _check_sizes_valid(self):
        '''Validate self.sizes against the files in self.files.

        Every size must be an int, must not be negative, and must not exceed
        the actual length reported by `hadoop fs -du` for the file at the same
        index. On any violation, or when the actual lengths cannot be read,
        the failure handler rolls back and the process exits with status 1.
        '''
        for sz in self.sizes:
            if type(sz) != type(1):
                logger.error('File size(%s) in yaml configuration file should be int type.' % sz)
                self.failure_handler.rollback()
                sys.exit(1)
            if sz < 0:
                logger.error('File size(%s) in yaml configuration file should not be less than 0.' % sz)
                self.failure_handler.rollback()
                sys.exit(1)
        if not self.files:
            # Nothing to compare; `hadoop fs -du` without a path would report on
            # the default (home) directory instead of the files to register.
            return
        hdfscmd = 'hadoop fs -du %s' % ' '.join(self.files)
        rc, outs, _ = local_ssh_output(hdfscmd)
        # Skip blank lines so an empty output or trailing newline cannot raise IndexError.
        # The index-based comparison below assumes one line per file, in argument order.
        lines = [line for line in outs.split('\n') if line.strip()]
        if rc != 0 or len(lines) != len(self.files):
            logger.error('Failed to get actual length of file(s) to be registered with command: %s' % hdfscmd)
            self.failure_handler.rollback()
            sys.exit(1)
        for k, line in enumerate(lines):
            actual_size = line.strip().split()[0]
            if self.sizes[k] > int(actual_size):
                if self.mode == 'usage1':
                    logger.error('Specified file size(%s) should not exceed actual length(%s) of file %s.' % (self.sizes[k], actual_size, self.files[k]))
                else:
                    logger.error('File size(%s) in yaml configuration file should not exceed actual length(%s) of file %s.' % (self.sizes[k], actual_size, self.files[k]))
                self.failure_handler.rollback()
                sys.exit(1)

    def _check_no_regex_filepath(self, files):
        for fn in files:
            tmp_lst = fn.split('/')
            for v in tmp_lst:
                if v == '.':
                    logger.error('Hawq register does not support file path with regex: %s.' % fn)
                    self.failure_handler.rollback()
                    sys.exit(1)
            for ch in ['..', '*']:
                if fn.find(ch) != -1:
                    logger.error('Hawq register does not support file path with regex: %s.' % fn)
                    self.failure_handler.rollback()
                    sys.exit(1)

    def prepare(self):
        '''Parse and validate the input, then build the catalog SQL in self.queries.

        With a yaml config (self.yml), the method:
          * parses options.yml_config;
          * takes the directory of the first listed file as self.filepath;
          * checks that every listed path is a file and that the encoding matches;
          * calls _create_table(). When that returns a falsy value and --force
            was not given, it switches to 'usage2_table_exist' mode.
        Without one (usage1), -e is rejected for a directory, the file format is
        Parquet, and only randomly distributed tables are accepted.
        The SQL is wrapped in a single transaction.
        '''
        if not self.yml:
            if self._is_folder(self.filepath) and self.filesize:
                logger.error('-e option is only supported with single file case.')
                sys.exit(1)
            self.file_format = 'Parquet'
            # Usage1 only supports randomly distributed tables.
            self._check_hash_type()
        else:
            # NOTE(review): this reads the module-level `options` set in __main__,
            # not an attribute of self.
            self._option_parser_yml(options.yml_config)
            if self.files:
                first_file = self.files[0]
                self.filepath = first_file[:first_file.rfind('/')]
            else:
                self.filepath = ''
            self._check_file_not_folder()
            self._check_database_encoding()
            table_created = self._create_table()
            if not table_created and self.mode != 'force':
                self.mode = 'usage2_table_exist'
        self.queries = "set allow_system_table_mods='dml';"
        self.queries += "begin transaction;"
        self._do_check()
        self._prepare_register()
        self.queries += "end transaction;"

    def _do_check(self):
        '''Run the per-table checks and split the files to register.

        Validates the yaml-derived settings (yaml mode) and resolves the
        table's segment name, self.firstsegno and self.tabledir. It then fills in:
          * self.files_update and the parallel self.*_update lists: files that
            already belong to the table (force mode only);
          * self.files_append and the parallel self.*_append lists: files to
            be newly registered;
          * self.do_not_move: True when force mode lists exactly the table's
            existing files, so no HDFS move is needed.
        Exits with status 0 when no filepath is set. When a check fails, the
        failure handler rolls back and the process exits with status 1.
        '''
        if self.yml:
            if self._is_hash_distributed():
                self._check_bucket_number()
            self._check_distribution_policy()
            self._check_policy_consistency()
            self._check_no_regex_filepath(self.files)
        if not self.filepath:
            if self.mode == 'usage1':
                logger.info('Please specify filepath with -f option.')
            else:
                logger.info('Hawq Register Succeed.')
            sys.exit(0)

        (self.seg_name, tmp_ret) = self._get_seg_name()
        if not tmp_ret:
            logger.error('Failed to get segment name')
            self.failure_handler.rollback()
            sys.exit(1)
        self.firstsegno, self.tabledir = self._get_metadata()

        # In yaml force mode the table's own directory is listed; otherwise the
        # directory of the files to be registered.
        if self.yml and self.mode == 'force':
            existed_files, _ = self._get_files_in_hdfs(self.tabledir)
        else:
            existed_files, _ = self._get_files_in_hdfs(self.filepath)
        # check if file numbers in hdfs is consistent with the record count of pg_aoseg.
        # NOTE(review): only segnos 1..firstsegno-2 are checked; confirm whether
        # firstsegno-1 is intentionally excluded.
        hdfs_file_no_lst = [f.split('/')[-1] for f in existed_files]
        for k in range(1, self.firstsegno - 1):
            if self.firstsegno - 1 > len(existed_files) or str(k) not in hdfs_file_no_lst:
                logger.error("Hawq aoseg metadata doesn't consistent with file numbers in hdfs.")
                self.failure_handler.rollback()
                sys.exit(1)

        if self.mode == 'usage2_table_exist':
            if self.tabledir.strip('/') == self.filepath.strip('/'):
                logger.error('Files to be registered should not be the same with table path.')
                self.failure_handler.rollback()
                sys.exit(1)

        if not self.yml:
            self._check_no_regex_filepath([self.filepath])
            self.files, self.sizes = self._get_files_in_hdfs(self.filepath)
            # usage1 has no per-file statistics: -1 for every file. Each column
            # gets its own list so that changing one never changes another.
            self.tupcounts = [-1] * len(self.files)
            self.eofuncompresseds = [-1] * len(self.files)
            self.varblockcounts = [-1] * len(self.files)

        self.do_not_move = False
        self.files_update, self.sizes_update, self.tupcounts_update, self.eofuncompresseds_update, self.varblockcounts_update = [], [], [], [], []
        self.files_append = list(self.files)
        self.sizes_append = list(self.sizes)
        self.tupcounts_append = list(self.tupcounts)
        self.eofuncompresseds_append = list(self.eofuncompresseds)
        self.varblockcounts_append = list(self.varblockcounts)
        if self.mode == 'force':
            force_msg = 'In force mode, you should include existing table files in yaml configuration file. Otherwise you should drop the previous table before register --force.'
            if len(self.files) == len(existed_files):
                if sorted(self.files) != sorted(existed_files):
                    logger.error(force_msg)
                    self.failure_handler.rollback()
                    sys.exit(1)
                else:
                    self.do_not_move, self.files_update, self.sizes_update, self.tupcounts_update, self.eofuncompresseds_update, self.varblockcounts_update = True, self.files, self.sizes, self.tupcounts, self.eofuncompresseds, self.varblockcounts
                self.files_append, self.sizes_append, self.tupcounts_append, self.eofuncompresseds_append, self.varblockcounts_append = [], [], [], [], []
            elif len(self.files) < len(existed_files):
                logger.error(force_msg)
                self.failure_handler.rollback()
                sys.exit(1)
            else:
                # Split the parallel lists by index. Removing by value would drop
                # the first *equal* element and misalign sizes/tupcounts/... with
                # their files whenever two files share a value.
                existed = set(existed_files)
                update_idx = [k for k, f in enumerate(self.files) if f in existed]
                append_idx = [k for k, f in enumerate(self.files) if f not in existed]

                def pick(values, indexes):
                    return [values[k] for k in indexes]

                self.files_update = pick(self.files, update_idx)
                self.sizes_update = pick(self.sizes, update_idx)
                self.tupcounts_update = pick(self.tupcounts, update_idx)
                self.eofuncompresseds_update = pick(self.eofuncompresseds, update_idx)
                self.varblockcounts_update = pick(self.varblockcounts, update_idx)
                self.files_append = pick(self.files, append_idx)
                self.sizes_append = pick(self.sizes, append_idx)
                self.tupcounts_append = pick(self.tupcounts, append_idx)
                self.eofuncompresseds_append = pick(self.eofuncompresseds, append_idx)
                self.varblockcounts_append = pick(self.varblockcounts, append_idx)
                if sorted(self.files_update) != sorted(existed_files):
                    logger.error(force_msg)
                    self.failure_handler.rollback()
                    sys.exit(1)

        self._check_files_and_table_in_same_hdfs_cluster(self.filepath, self.tabledir)

        logger.info('New file(s) to be registered: %s' % self.files_append)
        if self.files_update:
            logger.info('Catalog info need to be updated for these files: %s' % self.files_update)

        if self.filesize is not None:
            if len(self.files) != 1:
                logger.error('-e option is only supported with single file case.')
                self.failure_handler.rollback()
                sys.exit(1)
            self.sizes_append = [self.filesize]
            self.sizes = [self.filesize]
        self._check_sizes_valid()

        if self.file_format == 'Parquet':
            self._check_parquet_format(self.files)

    def test_set_move_files_in_hdfs(self):
        ''' Manual check for _set_move_files_in_hdfs; the printed output should be:
        self.files_update = ['1', '2', '3']
        self.files_same_path = ['5', '6', 'a']
        self.srcfiles=['5', '6', 'a', '1', '2', '3']
        self.dstfiles=['5', '6', '4', '7' , '8', '9']
        '''
        self.firstsegno = 4
        self.tabledir = ''
        self.files_update = ['1', '2', '3', '5', '6', 'a']
        self.files_append = ['1', '2', '3']
        # Each per-file column gets its own [1..6] list.
        self.sizes_update = list(range(1, 7))
        self.tupcounts_update = list(range(1, 7))
        self.eofuncompresseds_update = list(range(1, 7))
        self.varblockcounts_update = list(range(1, 7))
        self._set_move_files_in_hdfs()
        for result in (self.files_update, self.files_same_path, self.srcfiles, self.dstfiles):
            print(result)

    def _check_files_and_table_in_same_hdfs_cluster(self, filepath, tabledir):
        '''Check whether all the files refered by 'filepath' and the location corresponding to the table are in the same hdfs cluster'''
        if not filepath:
            return
        # check whether the files to be registered is in hdfs
        filesystem = filepath.split('://')
        if filesystem[0] != 'hdfs':
            logger.error('Only support registering file(s) in hdfs.')
            self.failure_handler.rollback()
            sys.exit(1)
        fileroot = filepath.split('/')
        tableroot = tabledir.split('/')
        # check the root url of them. eg: for 'hdfs://localhost:8020/temp/tempfile', we check 'hdfs://localohst:8020'
        if fileroot[0:3] != tableroot[0:3]:
            logger.error("Files to be registered and the table are not in the same hdfs cluster.\nFile(s) to be registered: '%s'\nTable path in HDFS: '%s'." % (filepath, tabledir))
            self.failure_handler.rollback()
            sys.exit(1)

    def _get_files_in_hdfs(self, filepath):
        '''List the regular files under `filepath` in HDFS, recursively.

        Returns a (files, sizes) pair of parallel lists taken from the path
        (8th) and size (5th) columns of `hadoop fs -ls -R`.
        Error handling:
          * The path does not exist: roll back and exit with status 1.
          * Nothing found in usage1 mode: exit with status 0.
          * Nothing found in any mode other than usage1 or force: roll back and
            exit with status 1.
        '''
        if local_ssh("hadoop fs -test -e %s" % filepath, logger) != 0:
            logger.error("Path '%s' does not exist in hdfs" % filepath)
            self.failure_handler.rollback()
            sys.exit(1)
        _, listing, _ = local_ssh_output("hadoop fs -ls -R %s" % filepath)
        files, sizes = [], []
        for fields in (line.split() for line in listing.splitlines()):
            # Keep only 8-column rows whose first (permission) column has no 'd'.
            if len(fields) != 8 or 'd' in fields[0]:
                continue
            files.append(fields[7])
            sizes.append(int(fields[4]))
        if not files:
            if self.mode == 'usage1':
                logger.info('No files to be registered.')
                logger.info('Hawq Register Succeed.')
                sys.exit(0)
            if self.mode != 'force':
                logger.error("Dir '%s' is empty" % filepath)
                self.failure_handler.rollback()
                sys.exit(1)
        return files, sizes

    def _check_parquet_format(self, files):
        '''Check that each non-empty file starts and ends with the Parquet magic "PAR1".

        A file whose human-readable `du` size starts with '0' (an empty file)
        is skipped. On the first file that fails, the failure handler rolls
        back and the process exits with status 1.
        '''
        for f in files:
            rc, out, err = local_ssh_output('hadoop fs -du -h %s | head -c1' % f)
            if out == '0':
                continue
            # Header magic: `head -c4` stops after the first 4 bytes.
            result1 = local_ssh('hadoop fs -cat %s | head -c4 | grep PAR1' % f)
            # Footer magic: `hadoop fs -tail` emits only the last kilobyte, so the
            # whole file is not streamed through `cat` just to read 4 bytes.
            result2 = local_ssh('hadoop fs -tail %s | tail -c4 | grep PAR1' % f)
            if result1 or result2:
                logger.error('File %s is not parquet format' % f)
                self.failure_handler.rollback()
                sys.exit(1)

    def _set_move_files_in_hdfs(self):
        segno = self.firstsegno
        # set self.files_same_path, self.sizes_same_path and self.segnos_same_path, which are for files existed in HDFS but not in catalog metadata
        update_segno_lst = [f.split('/')[-1] for f in self.files_update]
        catalog_lst = [str(i) for i in range(1, segno)]
        new_catalog_lst = [str(i) for i in range(segno, len(self.files_update) + 1)]
        exist_catalog_lst = []
        tmp_files_update = [f for f in self.files_update]
        tmp_sizes_update = [f for f in self.sizes_update]

        tmp_tupcounts_update = [f for f in self.tupcounts_update]
        tmp_eofuncompresseds_update = [f for f in self.eofuncompresseds_update]
        tmp_varblockcounts_update = [f for f in self.varblockcounts_update]
         
        for k, seg in enumerate(update_segno_lst):
            if seg not in catalog_lst:

                self.files_same_path.append(tmp_files_update[k])
                self.sizes_same_path.append(tmp_sizes_update[k])
                self.tupcounts_same_path.append(tmp_tupcounts_update[k])
                self.eofuncompresseds_same_path.append(tmp_eofuncompresseds_update[k])
                self.varblockcounts_same_path.append(tmp_varblockcounts_update[k])
           
                self.files_update.remove(tmp_files_update[k])
                self.sizes_update.remove(tmp_sizes_update[k])
                self.tupcounts_update.remove(tmp_tupcounts_update[k])
                self.eofuncompresseds_update.remove(tmp_eofuncompresseds_update[k])
                self.varblockcounts_update.remove(tmp_varblockcounts_update[k])

            if seg in new_catalog_lst:
                exist_catalog_lst.append(seg)
        for seg in update_segno_lst:
            if seg not in catalog_lst:
                if seg in exist_catalog_lst:
                    self.segnos_same_path.append(int(seg))
                else:
                    while (str(segno) in exist_catalog_lst):
                        segno += 1
                    self.segnos_same_path.append(segno)
                    segno += 1

        for k, f in enumerate(self.files_same_path):
            self.srcfiles.append(f)
            self.dstfiles.append(self.tabledir + str(self.segnos_same_path[k]))

        segno = self.firstsegno + len(self.files_same_path)
        for f in self.files_append:
            self.srcfiles.append(f)
            self.dstfiles.append(self.tabledir + str(segno))
            segno += 1

    def _move_files_in_hdfs(self):
        '''Move file(s) in src path into the folder correspoding to the target table'''
        for k, srcfile in enumerate(self.srcfiles):
            dstfile = self.dstfiles[k]
            if srcfile != dstfile:
                hdfscmd = 'hadoop fs -mv %s %s' % (srcfile, dstfile)
                sys.stdout.write('hdfscmd: "%s"\n' % hdfscmd)
                result = local_ssh(hdfscmd, logger)
                if result != 0:
                    logger.error('Fail to move %s to %s' % (srcfile, dstfile))
                    self.failure_handler.rollback()
                    sys.exit(1)
                self.failure_handler.commit(('HDFSCMD', hdfscmd))

    def _set_modify_metadata(self, mode):
        '''Append the pg_aoseg / relfile catalog SQL for the current table to self.queries.

        mode: 'force' first deletes every row of pg_aoseg.<self.seg_name>; any
        other value (the caller passes 'insert') only inserts rows.

        Rows are inserted for three groups of files, in this order:
          * update files (self.*_update): segno is the file's base name;
          * same-path files (self.*_same_path): segno from self.segnos_same_path;
          * append files (self.*_append): consecutive segnos starting at
            self.firstsegno + len(self.sizes_same_path).
        Parquet rows carry (segno, eof, tupcount, eofuncompressed). Other formats
        carry (segno, eof, tupcount, varblockcount, eofuncompressed).
        After the inserts, one gp_relfile_insert_for_register(...) call is
        appended for every same-path and append segno. Update segnos get no call.
        '''
        segno = self.firstsegno
        append_eofs = self.sizes_append
        update_eofs = self.sizes_update
        same_path_eofs = self.sizes_same_path
        append_tupcounts = self.tupcounts_append
        update_tupcounts = self.tupcounts_update
        same_path_tupcounts = self.tupcounts_same_path
        append_eofuncompresseds = self.eofuncompresseds_append
        update_eofuncompresseds = self.eofuncompresseds_update
        same_path_eofuncompresseds = self.eofuncompresseds_same_path
        append_varblockcounts = self.varblockcounts_append
        update_varblockcounts = self.varblockcounts_update
        same_path_varblockcounts = self.varblockcounts_same_path
        # Update files keep their segno: the base name of the file (a string, inserted via %s).
        update_segno_lst = [f.split('/')[-1] for f in self.files_update]
        same_path_segno_lst = [seg for seg in self.segnos_same_path]
        # NOTE(review): relfile_data is indexed 0..7 in the final loop; the meaning of
        # each field comes from _get_metadata_for_relfile_insert (not visible here).
        relfile_data = self._get_metadata_for_relfile_insert()
        query = ""
        # Segnos that also need a gp_relfile_insert_for_register() call.
        insert_relfile_segs = []
        
        if mode == 'force':
            query += "delete from pg_aoseg.%s;" % (self.seg_name)
        
        if self.file_format == 'Parquet': 
            # Each group below emits one multi-row INSERT: the first row comes
            # from index 0, the loop adds the rest using the running index k.
            if len(update_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%s, %d, %d, %d)' % (self.seg_name, update_segno_lst[0], update_eofs[0], update_tupcounts[0], update_eofuncompresseds[0])
                k = 0
                for update_eof, update_tupcount, update_eofuncompressed in zip(update_eofs[1:], update_tupcounts[1:], update_eofuncompresseds[1:]):
                    query += ',(%s, %d, %d, %d)' % (update_segno_lst[k + 1], update_eof, update_tupcount, update_eofuncompressed)
                    k += 1
                query += ';'
            if len(same_path_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%d, %d, %d, %d)' % (self.seg_name, same_path_segno_lst[0], same_path_eofs[0], same_path_tupcounts[0], same_path_eofuncompresseds[0])
                insert_relfile_segs.append(int(same_path_segno_lst[0]));
                k = 0
                for same_path_eof, same_path_tupcount, same_path_eofuncompressed in zip(same_path_eofs[1:], same_path_tupcounts[1:], same_path_eofuncompresseds[1:]):
                    query += ',(%d, %d, %d, %d)' % (same_path_segno_lst[k + 1], same_path_eof, same_path_tupcount, same_path_eofuncompressed)
                    insert_relfile_segs.append(int(same_path_segno_lst[k + 1]));
                    k += 1
                query += ';'
            # Appended files are numbered after the same-path files.
            segno += len(same_path_eofs)
            if len(append_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%d, %d, %d, %d)' % (self.seg_name, segno, append_eofs[0], append_tupcounts[0], append_eofuncompresseds[0])
                insert_relfile_segs.append(segno);
                k = 0
                for append_eof, append_tupcount, append_eofuncompressed in zip(append_eofs[1:], append_tupcounts[1:], append_eofuncompresseds[1:]):
                    query += ',(%d, %d, %d, %d)' % (segno + k + 1, append_eof, append_tupcount, append_eofuncompressed)
                    insert_relfile_segs.append(segno + k + 1);
                    k += 1
                query += ';'
        else:
            # Non-Parquet rows additionally carry the varblockcount column,
            # placed before eofuncompressed.
            if len(update_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%s, %d, %d, %d, %d)' % (self.seg_name, update_segno_lst[0], update_eofs[0], update_tupcounts[0], update_varblockcounts[0], update_eofuncompresseds[0])
                k = 0
                # NOTE: the loop variable below rebinds the name of the
                # update_eofuncompresseds list; the list is not used afterwards.
                for update_eof, update_tupcount, update_varblockcount, update_eofuncompresseds in zip(update_eofs[1:], update_tupcounts[1:], update_varblockcounts[1:], update_eofuncompresseds[1:]):
                    query += ',(%s, %d, %d, %d, %d)' % (update_segno_lst[k + 1], update_eof, update_tupcount, update_varblockcount, update_eofuncompresseds)
                    k += 1
                query += ';'
            if len(same_path_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%d, %d, %d, %d, %d)' % (self.seg_name, same_path_segno_lst[0], same_path_eofs[0], same_path_tupcounts[0], same_path_varblockcounts[0], same_path_eofuncompresseds[0])
                insert_relfile_segs.append(int(same_path_segno_lst[0]));
                k = 0
                for same_path_eof, same_path_tupcount, same_path_varblockcount, same_path_eofuncompressed in zip(same_path_eofs[1:], same_path_tupcounts[1:], same_path_varblockcounts[1:], same_path_eofuncompresseds[1:]):
                    query += ',(%d, %d, %d, %d, %d)' % (same_path_segno_lst[k + 1], same_path_eof, same_path_tupcount, same_path_varblockcount, same_path_eofuncompressed)
                    insert_relfile_segs.append(int(same_path_segno_lst[k + 1]));
                    k += 1
                query += ';'
            # Appended files are numbered after the same-path files.
            segno += len(same_path_eofs)
            if len(append_eofs) > 0:
                query += 'insert into pg_aoseg.%s values(%d, %d, %d, %d, %d)' % (self.seg_name, segno, append_eofs[0], append_tupcounts[0], append_varblockcounts[0], append_eofuncompresseds[0])
                insert_relfile_segs.append(segno);
                k = 0
                for append_eof, append_tupcount, append_varblockcount, append_eofuncompressed in zip(append_eofs[1:], append_tupcounts[1:], append_varblockcounts[1:], append_eofuncompresseds[1:]):
                    query += ',(%d, %d, %d, %d, %d)' % (segno + k + 1, append_eof, append_tupcount, append_varblockcount, append_eofuncompressed)
                    insert_relfile_segs.append(segno + k + 1);
                    k += 1
                query += ';'
        self.queries += query
        # One relfile registration call per newly placed segno.
        for seg in insert_relfile_segs:
            self.queries += "select gp_relfile_insert_for_register(%d, %d, %d, %d, %d, '%s', '%s', '%s', %d);" % (relfile_data[0], relfile_data[1], relfile_data[2], relfile_data[3], seg, relfile_data[4], relfile_data[5], relfile_data[6], relfile_data[7])

    
    def _modify_metadata(self):
        try:
            self.utility_accessor.update_catalog(self.queries)
        except pg.DatabaseError as e:
            print e
            logger.error('In _modify_metadata(): %s' % e);
            self.failure_handler.rollback()
            sys.exit(1)

    def _prepare_register(self):
        if not self.do_not_move:
            self._set_move_files_in_hdfs()
        if self.mode == 'force':
            self._set_modify_metadata('force')
        else:
            self._set_modify_metadata('insert')

    def register(self):
        '''Perform the planned HDFS moves (if any), apply the catalog SQL, and log success.'''
        files_must_move = not self.do_not_move
        if files_must_move:
            self._move_files_in_hdfs()
        self._modify_metadata()
        logger.info('Hawq Register Succeed.')


class HawqRegisterPartition(HawqRegister):
    '''HawqRegister variant for a partitioned table.

    Only one level of partitioning is supported. The parent table and every
    partition listed in the yaml are checked, and the catalog SQL for all of
    them is accumulated in self.queries inside a single transaction.
    '''
    def __init__(self, options, table, utility_conn, conn, failure_handler):
        HawqRegister.__init__(self, options, table, utility_conn, conn, failure_handler)

    def _get_partition_info(self):
        '''Return a dict mapping each partition boundary to its partition table name.'''
        dic = {}
        for ele in self.accessor.get_partition_info(self.dst_table_name):
            dic[ele['partitionboundary']] = ele['partitiontablename']
        return dic

    def _get_partition_parent(self):
        '''Return the accessor's partition-parent rows for the destination table.'''
        return self.accessor.get_partition_parent(self.dst_table_name)

    def _check_partitionby(self):
        '''Abort when self.partitionby differs from the existing table's PartitionBy.'''
        def get_partitionby():
            return self.accessor.get_partitionby(self.dst_table_name)

        if self.partitionby != get_partitionby():
            logger.error('PartitionBy of %s is not consistent with previous partitionby.' % self.tablename)
            self.failure_handler.rollback()
            sys.exit(1)

    def _check_partition_num(self):
        '''Abort when more partitions are listed than the existing table reports.'''
        def get_partition_num():
            return self.accessor.get_partition_num(self.dst_table_name)

        if get_partition_num() < len(self.partitions_name):
            logger.error('Partition Number of %s is not consistent with previous partition number.' % self.tablename)
            self.failure_handler.rollback()
            sys.exit(1)

    def _check_duplicate_constraint(self):
        '''Abort when two listed partitions share the same constraint.'''
        partitions_constraint = sorted(self.partitions_constraint)
        # After sorting, equal constraints are adjacent.
        for k, _ in enumerate(partitions_constraint):
            if k < len(partitions_constraint) - 1 and partitions_constraint[k] == partitions_constraint[k+1]:
                logger.error('Partition Constraint "%s" in table %s is duplicated' % (partitions_constraint[k], self.tablename))
                self.failure_handler.rollback()
                sys.exit(1)

    def prepare(self):
        '''Check the parent table and each partition, then build self.queries.

        Multi-level partition tables are rejected. With a yaml config:
          * every partition's files are checked;
          * the table is created if needed;
          * otherwise PartitionBy and the partition count are compared with the
            existing table.
        Then _do_check()/_prepare_register() run for the parent and for each
        partition, inside one transaction.
        '''
        if self.yml:
            # NOTE(review): reads the module-level `options` set in __main__.
            self._option_parser_yml(options.yml_config)
        else:
            if self._is_folder(self.filepath) and self.filesize:
                logger.error('-e option is only supported with single file case.')
                sys.exit(1)
            self.file_format = 'Parquet'
            self._check_hash_type() # Usage1 only support randomly distributed table

        # check if it is a multi-level partition table
        partitions_parents = self._get_partition_parent()
        if any(p['parentpartitiontablename'] for p in partitions_parents):
            logger.error('Multi-level partition table is not supported!')
            sys.exit(1)

        # Save the parent's settings; the loops below temporarily overwrite them.
        parent_tablename = self.tablename
        parent_files = self.files
        parent_sizes = self.sizes
        parent_tupcounts = self.tupcounts
        parent_eofuncompresseds = self.eofuncompresseds
        parent_varblockcounts = self.varblockcounts
        if self.yml:
            self.filepath = self.files[0][:self.files[0].rfind('/')] if self.files else ''
            self._check_file_not_folder()
        for k, pn in enumerate(self.partitions_name):
            self.tablename = pn
            self.files = self.partitions_filepaths[k]
            self.sizes = self.partitions_filesizes[k]
            if self.yml:
                self.filepath = self.files[0][:self.files[0].rfind('/')] if self.files else ''
                self._check_file_not_folder(pn)
        if self.yml:
            self._check_database_encoding()
            if not self._create_table() and self.mode != 'force':
                self.mode = 'usage2_table_exist'
                self._check_partitionby()
                self._check_partition_num()
        partitions = self._get_partition_info()

        self.queries = "set allow_system_table_mods='dml';"
        self.queries += "begin transaction;"
        self._check_duplicate_constraint()
        self.tablename = parent_tablename
        self.files = parent_files
        self.sizes = parent_sizes
        self.tupcounts = parent_tupcounts
        self.eofuncompresseds = parent_eofuncompresseds
        self.varblockcounts = parent_varblockcounts
        # NOTE(review): self.filepath is not restored. In yaml mode it still holds
        # the directory derived from the last partition's files in the loop
        # above, and _do_check() uses it for the parent and every partition.
        # Confirm this is intended.
        self._do_check()
        self._prepare_register()
        schemaname, _ = tablename_handler(self.dst_table_name)
        for k, pn in enumerate(self.partitions_name):
            self.constraint = self.partitions_constraint[k]
            if self.constraint not in partitions:
                # Report the parent table: after the first iteration
                # self.tablename holds the previous partition's name.
                logger.error('Partition Constraint "%s" is not in table %s' % (self.constraint, parent_tablename))
                self.failure_handler.rollback()
                sys.exit(1)
            self.tablename = schemaname + '.' + partitions[self.constraint]
            self.files = self.partitions_filepaths[k]
            self.sizes = self.partitions_filesizes[k]
            self.tupcounts = self.partitions_tupcounts[k]
            self.eofuncompresseds = self.partitions_eofuncompresseds[k]
            self.varblockcounts = self.partitions_varblockcounts[k]
            self._do_check()
            self._prepare_register()
        self.queries += "end transaction;"

    def register(self):
        '''Register exactly as HawqRegister does, using the accumulated moves and queries.'''
        HawqRegister.register(self)

def main(options, args):
    '''Connect to the database and register the table named in args[0].

    Returns 1 when the database cannot be reached or the
    gp_relfile_insert_for_register function is missing, and None otherwise.
    '''
    def connectdb(options):
        '''
        Open a utility-mode connection and a regular connection to the database
        described by `options`. A pg.InternalError propagates when either
        connection attempt fails.
        '''
        db_url = dbconn.DbURL(hostname=options.host, port=options.port,
                              dbname=options.database, username=options.user)
        logger.info('try to connect database %s:%s %s' % (db_url.pghost, db_url.pgport, db_url.pgdb))
        connect_args = dict(dbname=db_url.pgdb, host=db_url.pghost, port=db_url.pgport,
                            user=db_url.pguser, passwd=db_url.pgpass)
        utility_connection = pg.connect(opt='-c gp_session_role=utility', **connect_args)
        regular_connection = pg.connect(**connect_args)
        return utility_connection, regular_connection

    try:
        utility_conn, conn = connectdb(options)
    except pg.InternalError:
        logger.error('Fail to connect to database, this script can only be run when database is up.')
        return 1

    # The catalog SQL calls gp_relfile_insert_for_register; the error message
    # below points users to `hawq upgrade` when it is missing.
    sql = 'select count(*) from pg_proc  where proname=\'gp_relfile_insert_for_register\''
    proc_rows = conn.query(sql).dictresult()
    if int(proc_rows[0]['count']) == 0 :
        logger.error('Function \'gp_relfile_insert_for_register\' is not found, please run \'hawq upgrade\' then try again.')
        return 1

    failure_handler = FailureHandler(conn)
    is_partitioned = check_file_exist(options.yml_config) and ispartition(options.yml_config)
    register_cls = HawqRegisterPartition if is_partitioned else HawqRegister
    ins = register_cls(options, args[0], utility_conn, conn, failure_handler)
    ins.prepare()
    ins.register()
    conn.close()

def test(options, args):
    '''Debug helper: connect to the database and run HawqRegister.test_set_move_files_in_hdfs.

    Returns 1 when the database cannot be reached.
    '''
    def connectdb(options):
        '''
        Open a utility-mode connection and a regular connection to the database
        described by `options`. A pg.InternalError propagates when either
        connection attempt fails.
        '''
        db_url = dbconn.DbURL(hostname=options.host, port=options.port,
                              dbname=options.database, username=options.user)
        logger.info('try to connect database %s:%s %s' % (db_url.pghost, db_url.pgport, db_url.pgdb))
        connect_args = dict(dbname=db_url.pgdb, host=db_url.pghost, port=db_url.pgport,
                            user=db_url.pguser, passwd=db_url.pgpass)
        utility_connection = pg.connect(opt='-c gp_session_role=utility', **connect_args)
        regular_connection = pg.connect(**connect_args)
        return utility_connection, regular_connection

    try:
        utility_conn, conn = connectdb(options)
    except pg.InternalError:
        logger.error('Fail to connect to database, this script can only be run when database is up.')
        return 1

    handler = FailureHandler(conn)
    register_obj = HawqRegister(options, args[0], utility_conn, conn, handler)
    register_obj.test_set_move_files_in_hdfs()


if __name__ == '__main__':
    parser = option_parser()
    options, args = parser.parse_args()
    setup_tool_logging(EXECNAME, getLocalHostname(), getUserName(), logdir=options.logDir)
    # Exactly one table name is required.
    if len(args) != 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    # -f cannot be combined with -c or --force.
    if (options.yml_config or options.force) and options.filepath:
        parser.print_help(sys.stderr)
        sys.exit(1)
    if local_ssh('hdfs'):
        logger.error('Command "hdfs" is not available.')
        sys.exit(1)
    # main() returns 1 on early failures (database unreachable, missing
    # function) and None on success. Pass it to sys.exit so the shell sees a
    # non-zero status on failure; sys.exit(None) exits with 0.
    sys.exit(main(options, args))
