#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

'''
hawq extract [options] tablename

Options:
    -h hostname: host to connect to
    -p port: port to connect to
    -U username: user to connect as
    -d database: database to connect to
    -o output file: the output metadata file; if not set, the metadata is written to the terminal.
    -W: force password authentication
    -v: verbose
    -l: sets the directory for log files
    -?: help

hawq extract output YAML file format:

Version: string (xxx.xxx.xxx)
DBVersion: string
FileFormat: string (AO/Parquet)
TableName: string (schemaname.tablename)
DFS_URL: string (hdfs://127.0.0.1:9000)
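Encoding: UTF8
Distribution_Policy: string (DISTRIBUTED RANDOMLY, or DISTRIBUTED BY (col1,col2))
Bucketnum: int (only present when the table is DISTRIBUTED BY some columns)
# For Parquet tables the schema key is named Parquet_Schema instead of AO_Schema.
AO_Schema:
- name: string
  type: string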

AO_FileLocations:
  Blocksize: int
  Checksum: boolean
  CompressionType: string
  CompressionLevel: int
  PartitionBy: string
  Files:
  - path: string (/gpseg0/16385/35469/35470/1)
    size: long
    tupcount: long
    varblockcount: long
    eofuncompressed: long

  Partitions:
  - Blocksize: int
    Checksum: boolean
    CompressionType: string
    CompressionLevel: int
    Name: string
    Constraint: string
    Files:
    - path: string
      size: long
      tupcount: long
      varblockcount: long
      eofuncompressed: long

Parquet_FileLocations:
  RowGroupSize: long
  PageSize: long
  CompressionType: string
  CompressionLevel: int
  Checksum: boolean
  EnableDictionary: boolean
  PartitionBy: string
  Files:
  - path: string
    size: long
    tupcount: long
    eofuncompressed: long
  Partitions:
  - Name: string
    RowGroupSize: long
    PageSize: long
    CompressionType: string
    CompressionLevel: int
    Checksum: boolean
    EnableDictionary: boolean
    Constraint: string
    Files:
    - path: string
      size: long
      tupcount: long
      eofuncompressed: long
'''
import os, sys, optparse, getpass, re, urlparse
try:
    from gppylib.commands.unix import getLocalHostname, getUserName
    from gppylib.db import dbconn
    from gppylib.gplog import get_default_logger, setup_tool_logging
    from gppylib.gpparseopts import OptParser, OptChecker
    from pygresql import pg
    from pygresql.pgdb import DatabaseError
    import yaml
except ImportError, e:
    print e
    sys.stderr.write('cannot import module, please check that you have source greenplum_path.sh\n')
    sys.exit(2)


# Module-wide state: the script's own file name (directory part stripped),
# which main() hands to setup_tool_logging(), and the logger shared by every
# function in this file.
EXECNAME = os.path.basename(__file__)
logger = get_default_logger()

class GpExtractError(Exception):
    '''Error raised when table metadata cannot be extracted; main() reports it and exits with status 1.'''


class GpMetadataAccessor(object):
    '''
    Read-only helper that runs the catalog queries needed by hawq extract.

    On construction it caches a few facts about the current database:
    `dbid` (pg_database.oid), `dbname`, `spcid` (dat2tablespace),
    `dbencoding` and `dbversion` (the output of `select version()`).
    '''
    def __init__(self, conn):
        '''
        `conn` must offer `query(sql)` whose result has a `dictresult()`
        method; that is the only part of the connection this class uses.
        '''
        self.conn = conn

        rows = self.exec_query("""
        SELECT oid, datname, dat2tablespace,
               pg_encoding_to_char(encoding) encoding
        FROM pg_database WHERE datname=current_database()""")

        self.dbid   = rows[0]['oid']
        self.dbname = rows[0]['datname']
        self.spcid  = rows[0]['dat2tablespace']
        self.dbencoding = rows[0]['encoding']
        self.dbversion = self.exec_query('select version()')[0]['version']

    def exec_query(self, sql):
        '''Run `sql` and return the result rows as a list of dicts.'''
        return self.conn.query(sql).dictresult()

    @staticmethod
    def _quote_literal(value):
        '''
        Return `value` rendered as a SQL string literal.

        The E'...' form is used with both backslashes and single quotes
        doubled, so the literal is parsed the same way whether or not
        backslashes are treated as escapes by the server. This keeps names
        that come from the command line from breaking out of the literal.
        '''
        text = '%s' % (value,)
        return "E'%s'" % text.replace('\\', '\\\\').replace("'", "''")

    def _fetch_one(self, sql, errmsg):
        '''
        Run `sql` and return its first row.
        Raise GpExtractError(errmsg) when the query returns no row, so that
        the caller sees a reportable error instead of an IndexError.
        '''
        rows = self.exec_query(sql)
        if not rows:
            raise GpExtractError(errmsg)
        return rows[0]

    def _get_seg_relname(self, oid):
        '''Return the relname found by get_seg_name(oid), or raise GpExtractError.'''
        rows = self.get_seg_name(oid)
        if not rows:
            raise GpExtractError(
                'Cannot find the segment file table of relation with oid %s.' % oid)
        return rows[0]['relname']

    def _get_column_names_by_attnum(self, relid):
        '''
        Return {attnum: attname} for the live (not dropped) user columns of
        relation `relid`.
        '''
        qry = """
        SELECT attnum, attname
        FROM pg_attribute
        WHERE attrelid=%d and attnum > 0 and not attisdropped
        """ % relid
        return dict((int(r['attnum']), r['attname'])
                    for r in self.exec_query(qry))

    def get_segment_locations(self):
        '''
        Return the `fselocation` of every pg_filespace_entry row that belongs
        to the 'dfs_system' filespace. The query has no ORDER BY, so the
        order of the list is whatever the database returns.

        Example:
            accessor.get_segment_locations()
            -> ['hdfs://127.0.0.1:9000/gpseg0', 'hdfs://127.0.0.1:9000/gpseg1']
        '''
        qry = """
        SELECT fselocation
        FROM pg_filespace_entry
        JOIN pg_filespace fs ON fsefsoid=fs.oid
        WHERE fsname='dfs_system';
        """
        rows = self.exec_query(qry)
        return [r['fselocation'] for r in rows]

    def get_seg_name(self, oid):
        '''
        Return the rows holding the relname of the segment file table
        (pg_appendonly.segrelid) of relation `oid`. The list is empty when
        the relation has no pg_appendonly entry.

        Example:
            accessor.get_seg_name(35709)
            -> [{'relname': 'pg_aoseg_35709'}]
        '''
        qry = """
        SELECT pg_class2.relname
        FROM pg_class as pg_class1, pg_appendonly, pg_class as pg_class2
        WHERE pg_class1.oid = %d
        AND pg_class1.oid = pg_appendonly.relid
        AND pg_appendonly.segrelid = pg_class2.oid;
        """ % oid
        return self.exec_query(qry)

    def get_aoseg_files(self, oid):
        '''
        Return the rows of the segment file table of AO relation `oid`,
        ordered by file number. Each row has the keys fileno, filesize,
        tupcount, varblockcount and eofuncompressed.

        Raise GpExtractError if the relation has no segment file table.
        '''
        seg_name = self._get_seg_relname(oid)
        # seg_name comes from the catalog, not from user input.
        qry = """
        SELECT segno as fileno, eof as filesize, tupcount as tupcount,
        varblockcount as varblockcount, eofuncompressed as eofuncompressed
        FROM pg_aoseg.%s
        ORDER by fileno;
        """ % seg_name
        return self.exec_query(qry)

    def get_paqseg_files(self, oid):
        '''
        Return the rows of the segment file table of Parquet relation `oid`,
        ordered by file number. Each row has the keys fileno, filesize,
        tupcount and eofuncompressed.

        Raise GpExtractError if the relation has no segment file table.
        '''
        seg_name = self._get_seg_relname(oid)
        # seg_name comes from the catalog, not from user input.
        qry = """
        SELECT segno as fileno, eof as filesize, tupcount, eofuncompressed
        FROM pg_aoseg.%s
        ORDER by fileno;
        """ % seg_name
        return self.exec_query(qry)

    def get_pgclass(self, nspname, relname):
        '''
        Return given table's pg_class entry as a dict (with oid).
        Only ordinary tables (relkind 'r') are considered.

        Raise GpExtractError if no such table exists.
        '''
        qry = """
        SELECT c.oid, c.*
        FROM pg_class c JOIN pg_namespace n ON c.relnamespace=n.oid
        WHERE n.nspname=%s and c.relname=%s and relkind='r'
        """ % (self._quote_literal(nspname), self._quote_literal(relname))
        rows = self.exec_query(qry)
        if not rows:
            raise GpExtractError('Table %s.%s not exists!' % (nspname, relname))
        return rows[0]

    def get_schema(self, relid):
        '''
        Fetch schema of the table specified by oid `relid`.
        Return schema as a list of {'name': colname, 'type': coltype} dict,
        ordered by attribute number.
        '''
        qry = """
        SELECT attname as name, typname as type
        FROM pg_attribute a join pg_type t on a.atttypid = t.oid
        WHERE attrelid=%d and attnum > 0
        ORDER BY attnum asc;
        """ % relid
        return self.exec_query(qry)

    def get_appendonly_attrs(self, relid):
        '''
        Return the pg_appendonly attributes (blocksize, pagesize,
        compresslevel, checksum, compresstype) of relation `relid` as a dict.

        Raise GpExtractError if the relation has no pg_appendonly entry.
        '''
        qry = """
        SELECT blocksize, pagesize, compresslevel, checksum, compresstype
        FROM pg_appendonly WHERE relid=%d
        """ % relid
        return self._fetch_one(
            qry, 'Cannot find pg_appendonly entry of relation with oid %s.' % relid)

    def get_partitions(self, nspname, relname):
        '''
        Get table's partitions info from pg_partitions view.
        Return one dict per partition; the list is empty for a table that
        is not partitioned.
        '''
        qry = """
        SELECT partitionschemaname, partitiontablename, partitionname,
               partitiontype, parentpartitiontablename, partitionboundary
        FROM pg_partitions
        WHERE schemaname=%s and tablename=%s
        """ % (self._quote_literal(nspname), self._quote_literal(relname))
        return self.exec_query(qry)

    def get_partition_columns(self, nspname, relname):
        '''
        Get table's partition columns from pg_partition_columns view,
        ordered by their position in the partition key.
        '''
        qry = """
        SELECT columnname, partitionlevel
        FROM pg_partition_columns
        WHERE schemaname=%s and tablename=%s
        ORDER BY position_in_partition_key
        """ % (self._quote_literal(nspname), self._quote_literal(relname))
        return self.exec_query(qry)

    def get_distribution_policy_info(self, oid, relid):
        '''
        Get table's distribution policy from gp_distribution_policy.

        `oid` selects the gp_distribution_policy row (localoid) and `relid`
        selects the relation whose column names are looked up.

        Return 'DISTRIBUTED RANDOMLY' when the policy lists no columns,
        otherwise 'DISTRIBUTED BY (col1,col2)'.
        Raise GpExtractError if the policy row or one of its columns cannot
        be found.
        '''
        qry = """
        SELECT *
        FROM gp_distribution_policy
        WHERE localoid = '%s'
        """ % oid
        policy = self._fetch_one(
            qry, 'Cannot find distribution policy of relation with oid %s.' % oid
        )['attrnums']
        if not policy:
            return 'DISTRIBUTED RANDOMLY'

        # The array arrives either as its text form ('{1,2}') or, depending
        # on the driver, already split into a sequence.
        if isinstance(policy, (list, tuple)):
            attnums = [int(k) for k in policy]
        else:
            attnums = [int(k) for k in policy.strip('{}').split(',') if k.strip()]
        if not attnums:
            return 'DISTRIBUTED RANDOMLY'

        # Resolve names by attribute number. Indexing the get_schema() list
        # by attnum-1 is wrong once a column has been dropped, because
        # dropped columns are missing from that list.
        names = self._get_column_names_by_attnum(relid)
        missing = [n for n in attnums if n not in names]
        if missing:
            raise GpExtractError(
                'Cannot find distribution column(s) %s of relation with oid %s.'
                % (missing, relid))
        return 'DISTRIBUTED BY (' + ','.join([names[n] for n in attnums]) + ')'

    def get_bucket_number(self, oid):
        '''
        Get table's bucket number from gp_distribution_policy.
        Raise GpExtractError if the relation has no policy row.
        '''
        qry = """
        SELECT bucketnum
        FROM gp_distribution_policy
        WHERE localoid = '%s'
        """ % oid
        return self._fetch_one(
            qry, 'Cannot find distribution policy of relation with oid %s.' % oid
        )['bucketnum']


def connectdb(options):
    '''
    Open a connection to the database described by `options` (host, port,
    dbname, user) and return it.

    When `options.use_getpass` is set the password is read interactively
    with getpass; otherwise no password is passed to DbURL.
    The session is opened with gp_session_role=utility.
    main() expects a failed connection to surface as pg.InternalError.
    '''
    password = None
    if options.use_getpass:
        password = getpass.getpass()

    url = dbconn.DbURL(hostname=options.host,
                       port=options.port,
                       dbname=options.dbname,
                       username=options.user,
                       password=password)

    logger.info('try to connect database %s:%s %s' % (url.pghost, url.pgport, url.pgdb))

    return pg.connect(dbname=url.pgdb,
                      host=url.pghost,
                      port=url.pgport,
                      user=url.pguser,
                      passwd=url.pgpass,
                      opt='-c gp_session_role=utility')


def extract_metadata(conn, tbname):
    '''
    Return table metadata as a dict for table `tbname`.

    `tbname` is lower-cased and may be schema-qualified ('schema.table').
    The result always holds Version, DBVersion, FileFormat, TableName,
    DFS_URL, Encoding and Distribution_Policy. Depending on the table's
    format it also holds AO_FileLocations/AO_Schema or
    Parquet_FileLocations/Parquet_Schema, and Bucketnum when the table is
    distributed by columns.

    Raise GpExtractError when the table does not exist, is not an AO or
    Parquet table, uses multi-level or multi-column partitioning, has a
    partition in a different format, or when no 'dfs_system' location is
    registered.
    '''
    accessor            = GpMetadataAccessor(conn)
    nspname, relname    = get_qualified_tablename(tbname.lower())
    rel_pgclass         = accessor.get_pgclass(nspname, relname)

    segment_locations = accessor.get_segment_locations()
    if not segment_locations:
        raise GpExtractError("Cannot find any location of filespace 'dfs_system'.")
    # Only the first location is used, both for the DFS URL and as the path
    # prefix of every data file.
    dfs_url, segprefix = split_segment_location(segment_locations[0])

    file_format = get_table_format(rel_pgclass['reloptions'])
    logger.info('-- detect FileFormat: %s' % file_format)

    metadata = { 'Version':     '1.0.0',
                 'DBVersion':   accessor.dbversion,
                 'FileFormat':  file_format,
                 'TableName':   '%s.%s' % (nspname, relname),
                 'DFS_URL':     dfs_url,
                 'Encoding':    accessor.dbencoding }

    partitions          = accessor.get_partitions(nspname, relname)
    partition_columns   = accessor.get_partition_columns(nspname, relname)

    partitionby = None
    if partitions:
        logger.info('-- detect partitions')
        if any(p['parentpartitiontablename'] for p in partitions):
            raise GpExtractError('Sorry, multi-level partition '
                                 'table is not supported!')

        if len(partition_columns) != 1:
            raise GpExtractError('Sorry, table partitioned by multiple '
                                 'columns is not supported!')

        partitionby = 'PARTITION BY %s (%s)' % (partitions[0]['partitiontype'],
                                                partition_columns[0]['columnname'])

    def build_file_list(seg_rows, relfilenode, counters):
        '''
        Turn segment file rows into [{'path': ..., 'size': ..., ...}].

        The path is <segment prefix>/<tablespace oid>/<database oid>/
        <relfilenode>/<fileno> and does not include the DFS URL. Every key
        named in `counters` is copied from the row as an int.
        '''
        files = []
        for f in seg_rows:
            entry = {
                'path': '%s/%d/%d/%d/%d' % (
                    segprefix,      # dfs segment prefix
                    accessor.spcid, # tablespace oid
                    accessor.dbid,  # database oid
                    relfilenode,
                    f['fileno']
                ),
                'size': int(f['filesize']),
            }
            for key in counters:
                entry[key] = int(f[key])
            files.append(entry)
        return files

    def get_ao_table_files(oid, relfilenode):
        '''
        Given AO table's oid and relfilenode, return path and size (plus
        tupcount, varblockcount and eofuncompressed) of all its data files.

        Example, with segment prefix '/hawq_default', tablespace oid 16385,
        database oid 35469 and file numbers 1 and 2:
            get_ao_table_files(35488, 35470)
            -> [{'path': '/hawq_default/16385/35469/35470/1', 'size': 180, ...},
                {'path': '/hawq_default/16385/35469/35470/2', 'size': 150, ...}]
        '''
        return build_file_list(accessor.get_aoseg_files(oid), relfilenode,
                               ('tupcount', 'varblockcount', 'eofuncompressed'))

    def get_parquet_table_files(oid, relfilenode):
        '''
        The same as get_ao_table_files, except that it's for Parquet tables
        (whose file rows carry no varblockcount).
        '''
        return build_file_list(accessor.get_paqseg_files(oid), relfilenode,
                               ('tupcount', 'eofuncompressed'))

    def ao_storage_info(attrs):
        '''Map pg_appendonly attributes to the AO_FileLocations keys.'''
        return {
            'Blocksize':        attrs['blocksize'],
            'Checksum':         attrs['checksum'] == 't',
            'CompressionType':  attrs['compresstype'],
            'CompressionLevel': attrs['compresslevel'],
        }

    def parquet_storage_info(attrs):
        '''Map pg_appendonly attributes to the Parquet_FileLocations keys.'''
        return {
            'RowGroupSize':     attrs['blocksize'],
            'PageSize':         attrs['pagesize'],
            'CompressionType':  attrs['compresstype'],
            'CompressionLevel': attrs['compresslevel'],
            'Checksum':         attrs['checksum'] == 't',
            'EnableDictionary': False,
        }

    def extract_format_metadata(storage_info, list_files):
        '''
        Fill `metadata` with the <format>_FileLocations, <format>_Schema,
        Distribution_Policy and (if applicable) Bucketnum entries.
        `storage_info` and `list_files` are the format specific helpers.
        '''
        relid = rel_pgclass['oid']
        rel_attrs = accessor.get_appendonly_attrs(relid)

        logger.info('-- extract %s_FileLocations' % file_format)
        file_locations = storage_info(rel_attrs)
        file_locations['Files'] = list_files(relid, rel_pgclass['relfilenode'])

        if partitions:
            file_locations['PartitionBy'] = partitionby

            # fill Partitions
            file_locations['Partitions'] = []
            for p in partitions:
                p_pgclass = accessor.get_pgclass(p['partitionschemaname'],
                                                 p['partitiontablename'])

                if get_table_format(p_pgclass['reloptions']) != file_format:
                    raise GpExtractError("table '%s' is not %s" % (
                        p_pgclass['relname'], file_format))

                par_info = storage_info(accessor.get_appendonly_attrs(p_pgclass['oid']))
                par_info['Name'] = p['partitiontablename']
                par_info['Constraint'] = p['partitionboundary']
                par_info['Files'] = list_files(p_pgclass['oid'],
                                               p_pgclass['relfilenode'])
                file_locations['Partitions'].append(par_info)

        metadata['%s_FileLocations' % file_format] = file_locations
        logger.info('-- extract %s_Schema' % file_format)
        metadata['%s_Schema' % file_format] = accessor.get_schema(relid)

        logger.info('-- extract Distribution_Policy')
        # Query the policy once and reuse it for the bucket number check.
        policy = accessor.get_distribution_policy_info(relid, relid)
        metadata['Distribution_Policy'] = policy
        logger.info('-- extract bucket number')
        if policy.startswith('DISTRIBUTED BY'):
            metadata['Bucketnum'] = accessor.get_bucket_number(relid)

    # extract AO/Parquet specific metadata
    cases = { 'AO': (ao_storage_info, get_ao_table_files),
              'Parquet': (parquet_storage_info, get_parquet_table_files) }

    extract_format_metadata(*cases[file_format])
    return metadata


def get_qualified_tablename(tbname):
    '''
    Split `tbname` into a (nspname, relname) tuple.

    'schema.table' yields ('schema', 'table'); a bare 'table' is placed in
    the 'public' schema. Raise GpExtractError when the name has more than
    two dot-separated parts or an empty part (e.g. '', 'a.' or '.b').
    '''
    parts = tbname.split('.')
    if len(parts) > 2 or not all(parts):
        raise GpExtractError('Invalid table name: ' + tbname)

    if len(parts) == 2:
        # Always hand back a tuple, whichever branch is taken.
        return tuple(parts)
    return ('public', tbname)


def split_segment_location(segloc):
    '''
    Split segment location into two parts, DFS URL and segment prefix.

    Example:
        split_segment_location('hdfs://127.0.0.1:9000/gpsql/gpseg0')
        -> ('hdfs://127.0.0.1:9000', '/gpsql/gpseg0')

    The URL keeps the scheme found in `segloc`. The prefix is '' when the
    location has no path. Raise GpExtractError when `segloc` contains no
    '://' separator.
    '''
    # Split by hand instead of going through urlparse: Python 2.6's
    # urlparse cannot handle hdfs:// urls, and rewriting the first four
    # characters to 'http' only works for schemes of exactly that length.
    scheme, sep, rest = segloc.partition('://')
    if not sep or not scheme:
        raise GpExtractError('Invalid segment location: ' + segloc)

    netloc, slash, tail = rest.partition('/')
    return ('%s://%s' % (scheme, netloc),
            slash + tail)


def get_table_format(reloptions):
    '''
    From table's `reloptions`, return its storage format.
    It should be 'AO' or 'Parquet', otherwise raise GpExtractError.

    `reloptions` is normally the text form of pg_class.reloptions, e.g.
    '{appendonly=true,orientation=parquet}'. A list/tuple of option strings
    is accepted as well, and None (a table without any reloptions) is
    treated as an empty option list, so such a table is rejected with
    GpExtractError rather than failing with a TypeError.
    '''
    if reloptions is None:
        reloptions = ''
    elif isinstance(reloptions, (list, tuple)):
        reloptions = ','.join(reloptions)

    # Column-oriented tables are checked first: they also carry
    # appendonly=true and must not be reported as plain AO tables.
    if 'orientation=column' in reloptions:
        raise GpExtractError('Sorry, CO table is not supported.')
    if 'parquet' in reloptions:
        return 'Parquet'
    if 'appendonly=true' in reloptions:
        return 'AO'
    raise GpExtractError('Sorry, only AO and Parquet table are supported.')


def create_opt_parser(version):
    '''
    Build the command line parser of the tool; `version` is the version
    string handed to the parser.

    The parser's existing '-h' option is removed first, so that '-h' can
    name the host and help is offered as '-?' / '--help' instead.
    '''
    opt_parser = OptParser(option_class=OptChecker,
                           usage='usage: %prog [options] tablename',
                           version=version)
    opt_parser.remove_option('-h')

    # (flags, keyword arguments) of every option, in registration order.
    option_specs = (
        (('-?', '--help'),    dict(action='help')),
        (('-h', '--host'),    dict(help="host of the target DB")),
        (('-p', '--port'),    dict(help="port of the target DB", type='int', default=0)),
        (('-U', '--user'),    dict(help="user of the target DB")),
        (('-d', '--dbname'),  dict(help="target database name")),
        (('-o', '--output'),  dict(help="the output metadata file, defaults to stdout", metavar='FILE')),
        (('-W',),             dict(action='store_true', dest='use_getpass', help="force password authentication")),
        (('-v', '--verbose'), dict(action='store_true')),
        (('-l', '--logdir'),  dict(dest='logDir', help="Sets the directory for log files")),
    )
    for flags, settings in option_specs:
        opt_parser.add_option(*flags, **settings)
    return opt_parser


def main(args=None):
    '''
    Entry point of the tool.

    Parse `args` (the parser's default argument source is used when None),
    connect to the database, extract the metadata of the requested table
    and dump it as YAML to the file given with -o, or to stdout.

    Return 1 on a usage error, a failed connection, a failed extraction or
    an I/O error while writing; return None on success.
    '''
    parser = create_opt_parser('%prog version $Revision: #1 $')

    options, args = parser.parse_args(args)
    setup_tool_logging(EXECNAME, getLocalHostname(), getUserName(), logdir=options.logDir)
    if len(args) != 1:
        if not args:
            sys.stderr.write('Incorrect number of arguments: missing <tablename>.\n\n')
        else:
            sys.stderr.write('Incorrect number of arguments: expected exactly one <tablename>.\n\n')
        parser.print_help(sys.stderr)
        return 1
    if args[0] == 'help':
        parser.print_help(sys.stderr)
        return 1

    # 1. connect db
    try:
        conn = connectdb(options)
    except pg.InternalError:
        logger.error('Failed to connect to database, this script can only be run when the database is up.')
        return 1

    # 2. extract metadata
    logger.info("try to extract metadata of table '%s'" % (args[0]))
    try:
        metadata = extract_metadata(conn, args[0])
    except GpExtractError as e:
        logger.error('Failed to extract metadata: ' + str(e))
        return 1
    finally:
        # The connection is only needed for the extraction itself.
        conn.close()

    # 3. dump to file in YAML format
    try:
        if options.output:
            fout = open(options.output, 'w')
            try:
                yaml.dump(metadata, stream=fout, default_flow_style=False)
            finally:
                # Close inside the outer try so that a failing flush is
                # reported like any other I/O error.
                fout.close()
        else:
            yaml.dump(metadata, stream=sys.stdout, default_flow_style=False)
    except IOError as e:
        logger.error(str(e))
        return 1

    if options.output:
        logger.info('metadata has been exported to file %s' % options.output)


# When run as a script, use main()'s return value as the process exit
# status (sys.exit treats a None return value as status 0).
if __name__ == '__main__':
    sys.exit(main())
