
Database Sync Script


We're migrating to new servers at work, but the old ones still carry some live traffic, so I threw together a quick script to keep syncing data out of the old database in the meantime.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import logging
import sys
import threading
import time
from datetime import datetime, timedelta
from queue import Queue
from urllib.parse import quote_plus
import pytz
from sqlalchemy import text, create_engine
from sqlalchemy.engine import Engine, Row
from dataclasses import dataclass
from typing import List, Optional

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('idc_db to ali_db')

# Size of each incremental sync window (used by the main loop below);
# tune to match your data volume.
TIME_WINDOW = timedelta(minutes=2)

# Batch size for bulk inserts; tune to match your data volume.
BATCH_INSERT_SIZE = 4096

# Global engines, populated by init_source_pool / init_target_pool
SOURCE_ENGINE: Optional[Engine] = None
TARGET_ENGINE: Optional[Engine] = None

@dataclass
class QueryResult:
    columns: List[str]
    rows: List[Row]

def init_source_pool(db_type, host, port, dbname, user, password, pool_size: int = 5, max_overflow: int = 10):
    """
    Create the global source SQLAlchemy Engine from the connection parameters.

    :param db_type: database type (e.g. 'mysql', 'postgresql')
    :param host: host address
    :param port: port number
    :param dbname: database name
    :param user: username
    :param password: password
    :param pool_size: connection pool size
    :param max_overflow: maximum overflow connections for the pool
    """
    driver_map = {
        'mysql': 'mysql+pymysql',
        'postgresql': 'postgresql+psycopg2',
        'sqlite': 'sqlite',
        # ......
    }
    global SOURCE_ENGINE
    try:
        if db_type not in driver_map:
            raise ValueError(f"Unsupported database type: {db_type}")

        driver = driver_map[db_type]

        if db_type == 'sqlite':
            db_url = f'{driver}:///{dbname}'
        else:
            encoded_password = quote_plus(password)
            db_url = f'{driver}://{user}:{encoded_password}@{host}:{port}/{dbname}'

        # Create the engine (connection pool)
        SOURCE_ENGINE = create_engine(db_url, pool_size=pool_size, max_overflow=max_overflow, echo=False)
        logger.info('[source] engine initialized')
    except Exception as e:
        logger.error(f'[source] engine init failed: {str(e)}', exc_info=True)

def init_target_pool(db_type, host, port, dbname, user, password, pool_size: int = 5, max_overflow: int = 10):
    """
    Create the global target SQLAlchemy Engine from the connection parameters.

    :param db_type: database type (e.g. 'mysql', 'postgresql')
    :param host: host address
    :param port: port number
    :param dbname: database name
    :param user: username
    :param password: password
    :param pool_size: connection pool size
    :param max_overflow: maximum overflow connections for the pool
    """
    driver_map = {
        'mysql': 'mysql+pymysql',
        'postgresql': 'postgresql+psycopg2',
        'sqlite': 'sqlite',
        # ......
    }
    global TARGET_ENGINE
    try:
        if db_type not in driver_map:
            raise ValueError(f"Unsupported database type: {db_type}")

        driver = driver_map[db_type]

        if db_type == 'sqlite':
            db_url = f'{driver}:///{dbname}'
        else:
            encoded_password = quote_plus(password)
            db_url = f'{driver}://{user}:{encoded_password}@{host}:{port}/{dbname}'

        # Create the engine (connection pool)
        TARGET_ENGINE = create_engine(db_url, pool_size=pool_size, max_overflow=max_overflow, echo=False)
        logger.info('[target] engine initialized')
    except Exception as e:
        logger.error(f'[target] engine init failed: {str(e)}', exc_info=True)

def connect_to_source():
    try:
        conn = SOURCE_ENGINE.connect()
        logger.info('[source] connected')
        return conn
    except Exception as e:
        logger.error(f'[source] connection failed: {str(e)}', exc_info=True)
        return None

def connect_to_target():
    try:
        conn = TARGET_ENGINE.connect()
        logger.info('[target] connected')
        return conn
    except Exception as e:
        logger.error(f'[target] connection failed: {str(e)}', exc_info=True)
        return None

def export_source_data(conn, table_name, start_time, end_time, batch_size=1000) -> Optional[QueryResult]:
    """Export rows from table_name whose create time falls in (start_time, end_time],
    both given as millisecond epoch timestamps.
    Note: batch_size is currently unused; all rows in the window are fetched at once."""
    try:
        logger.info(f"[export] query window: {start_time} -- {end_time}")
        # Convert the millisecond epoch timestamps to Beijing-time datetimes
        start_dt = datetime.fromtimestamp(start_time / 1000, pytz.timezone('Asia/Shanghai'))
        end_dt = datetime.fromtimestamp(end_time / 1000, pytz.timezone('Asia/Shanghai'))

        lower_timestamp = start_dt.strftime("%Y-%m-%d %H:%M:%S")
        upper_timestamp = end_dt.strftime("%Y-%m-%d %H:%M:%S")
        # report_event_log_filter uses create_date; every other table uses create_time.
        # Timestamps go in as bind parameters; the table name comes from the CLI args.
        time_col = 'create_date' if table_name == 'report_event_log_filter' else 'create_time'
        sql_query = text(
            f"SELECT * FROM {table_name} "
            f"WHERE {time_col} > :lower_ts AND {time_col} <= :upper_ts"
        )
        logger.info(f"[export] SQL: {sql_query}")

        # Execute the query
        try:
            result = conn.execute(sql_query, {"lower_ts": lower_timestamp, "upper_ts": upper_timestamp})
            columns = list(result.keys())
            rows = result.fetchall()
            logger.info(f"[export] query returned {len(rows)} rows")
        except Exception as e:
            logger.error(f"[export] query failed: {str(e)}", exc_info=True)
            return None

        if not rows:
            logger.warning("[export] no rows in this window")
            return None

        return QueryResult(columns=columns, rows=rows)
    except Exception as e:
        logger.error(f"[export] export failed: {str(e)}", exc_info=True)
        return None
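
# Hypothetical variant (not used above): for very large windows, fetchall()
# pulls the whole result set into memory at once. With SQLAlchemy 1.4+ one
# option is to stream the result and hand it out in chunks. A sketch, assuming
# the driver supports server-side cursors (pymysql does via stream_results):
def export_source_data_streaming(conn, sql_query, params, chunk_size=10000):
    result = conn.execution_options(stream_results=True).execute(sql_query, params)
    while True:
        chunk = result.fetchmany(chunk_size)
        if not chunk:
            break
        yield chunk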

def insert_target_data(conn, table_name, data: QueryResult, use_transaction: bool = True):
    try:
        if not data or not data.rows:
            logger.info("[insert] nothing to insert")
            return
        # Build the INSERT statement from the exported column names
        columns = data.columns
        placeholders = ', '.join(f':{col}' for col in columns)
        col_names = ', '.join(columns)

        insert_sql = text(f"""INSERT INTO {table_name} ({col_names}) VALUES ({placeholders})""")
        logger.info(f"[insert] SQL: {insert_sql}")
        logger.info(f"[insert] bulk-inserting into {table_name}")
        row_count = 0
        try:
            # Row -> plain dict (SQLAlchemy 1.4+ Row exposes ._mapping)
            values = [dict(row._mapping) for row in data.rows]
            for i in range(0, len(values), BATCH_INSERT_SIZE):
                batch = values[i:i + BATCH_INSERT_SIZE]
                if use_transaction:
                    trans = conn.begin()
                    try:
                        conn.execute(insert_sql, batch)
                        trans.commit()
                    except Exception as e:
                        trans.rollback()
                        logger.error(f"[insert] batch of {len(batch)} failed: {str(e)}", exc_info=True)
                else:
                    conn.execute(insert_sql, batch)
                # Count by batch size for an accurate total
                row_count += len(batch)
                logger.debug(f"[insert] inserted {len(batch)} rows")
        except Exception as e:
            logger.error(f"[insert] error while inserting: {str(e)}", exc_info=True)
    except Exception as e:
        logger.error(f"[insert] insert failed: {str(e)}", exc_info=True)

def get_max_createtime(conn, table_name):
    """Return the max create time already present in the ALI MySQL (target) table,
    as a millisecond epoch timestamp, or 0 if the table is empty."""
    try:
        if table_name == "report_event_log_filter":
            sql = f"SELECT MAX(create_date) AS create_time FROM {table_name}"
        else:
            sql = f"SELECT MAX(create_time) AS create_time FROM {table_name}"
        result = conn.execute(text(sql)).fetchone()

        if result and result[0] is not None:
            val = result[0]
            if isinstance(val, datetime):
                return int(val.timestamp() * 1000)
            return int(val)
        return 0
    except Exception as e:
        logger.error(f"Failed to fetch max create_time: {str(e)}", exc_info=True)
        return 0

def data_consumer(queue, conn, table_name):
    """Consumer thread: drain QueryResult batches from the queue and insert
    them into the target table. A None item is the shutdown signal."""
    while True:
        data = queue.get()
        if data is None:
            break
        try:
            insert_target_data(conn, table_name, data)
        except Exception as e:
            logger.error(f"Consumer error: {str(e)}", exc_info=True)
        queue.task_done()

def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Incrementally export data from IDC MySQL into ALI MySQL')
    parser.add_argument('--source-db-type', default='mysql', help='source database type [mysql]')
    parser.add_argument('--source-host', default='192.168.1.130', help='IDC MySQL host')
    parser.add_argument('--source-port', type=int, default=3306, help='IDC MySQL port')
    parser.add_argument('--source-db', default='test', help='IDC MySQL database name')
    parser.add_argument('--source-table', default='tb_test', help='IDC MySQL table')
    parser.add_argument('--source-user', default='root', help='IDC MySQL username')
    parser.add_argument('--source-password', default='root', help='IDC MySQL password')

    parser.add_argument('--target-db-type', default='mysql', help='target database type [mysql]')
    parser.add_argument('--target-host', default='192.168.1.131', help='ALI MySQL host')
    parser.add_argument('--target-port', type=int, default=3306, help='ALI MySQL port')
    parser.add_argument('--target-db', default='test', help='ALI MySQL database name')
    parser.add_argument('--target-table', default='tb_test', help='ALI MySQL table')
    parser.add_argument('--target-user', default='root', help='ALI MySQL username')
    parser.add_argument('--target-password', default='root', help='ALI MySQL password')

    parser.add_argument('--batch-size', type=int, default=1000, help='rows per batch')
    parser.add_argument('--sleep-interval', type=int, default=10, help='seconds to sleep when a window has no data')
    args = parser.parse_args()

    # Initialize the connection pools
    init_source_pool(args.source_db_type, args.source_host, args.source_port, args.source_db, args.source_user, args.source_password)
    init_target_pool(args.target_db_type, args.target_host, args.target_port, args.target_db, args.target_user, args.target_password)

    # Connect to the source (IDC MySQL)
    source_conn = connect_to_source()
    if not source_conn:
        logger.error("Cannot connect to [source db]; aborting")
        sys.exit(1)

    # Connect to the target (ALI MySQL)
    target_conn = connect_to_target()
    if not target_conn:
        logger.error("Cannot connect to [target db]; aborting")
        sys.exit(1)

    try:
        # Get the max create time already present in the target table
        start_createtime = get_max_createtime(target_conn, args.target_table)

        # Beijing time zone used for "now"
        beijing_tz = pytz.timezone('Asia/Shanghai')

        # Sync window size in milliseconds, from TIME_WINDOW above
        window_ms = int(TIME_WINDOW.total_seconds() * 1000)

        # Create the queue and the consumer thread that writes to the target
        data_queue = Queue()
        consumer_thread = threading.Thread(target=data_consumer, args=(data_queue, target_conn, args.target_table))
        consumer_thread.start()

        while True:
            current_time = int(datetime.now(beijing_tz).timestamp() * 1000)
            if start_createtime == 0:
                # No max time in the target yet: start one window before now
                start_createtime = current_time - window_ms

            if start_createtime > current_time:
                start_createtime = current_time

            end_createtime = start_createtime + window_ms
            if end_createtime > current_time:
                end_createtime = current_time

            logger.info(f"start_createtime: {start_createtime}")
            logger.info(f"end_createtime: {end_createtime}")

            # Export one window of rows from the source
            source_data = export_source_data(source_conn, args.source_table, start_createtime, end_createtime, args.batch_size)
            if not source_data or not source_data.rows:
                logger.info("No data to import; sleeping")
                time.sleep(args.sleep_interval)
                continue

            # Hand the batch to the consumer thread
            data_queue.put(source_data)

            # Advance the window start to the max create time just processed
            if args.source_table == "report_event_log_filter":
                max_processed_time = max(int(row._mapping['create_date'].timestamp() * 1000) for row in source_data.rows)
            else:
                max_processed_time = max(int(row._mapping['create_time'].timestamp() * 1000) for row in source_data.rows)
            start_createtime = max_processed_time
            logger.info(f"last_createtime: {start_createtime}")
            logger.info("---------------------------------------")
    except KeyboardInterrupt:
        logger.info("Interrupt received; shutting down")
        data_queue.put(None)  # signal the consumer thread to exit
        consumer_thread.join()  # wait for the consumer to finish
    finally:
        target_conn.close()
        source_conn.close()
        logger.info("Connections closed")

if __name__ == "__main__":
    main()
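
For reference, a typical invocation looks something like the following. The script name here is hypothetical, and the hosts, credentials, and table names are just the argparse defaults, not real values:

python sync_db.py \
    --source-host 192.168.1.130 --source-port 3306 \
    --source-db test --source-table tb_test \
    --source-user root --source-password root \
    --target-host 192.168.1.131 --target-port 3306 \
    --target-db test --target-table tb_test \
    --target-user root --target-password root \
    --sleep-interval 10

The script then loops forever: each pass exports one TIME_WINDOW-sized slice from the source, queues it for the consumer thread to insert into the target, and sleeps for --sleep-interval seconds whenever a window comes back empty. Stop it with Ctrl-C.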
