From 5319c06f0e695b42fae3ae9a669d7336480cbe3f Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 15 Dec 2025 14:29:47 +0800 Subject: [PATCH] migrate: remove model mappings and add aliases support --- ...f0cbf_remove_model_mappings_add_aliases.py | 205 ++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 alembic/versions/20251214_1300_e9b3d63f0cbf_remove_model_mappings_add_aliases.py diff --git a/alembic/versions/20251214_1300_e9b3d63f0cbf_remove_model_mappings_add_aliases.py b/alembic/versions/20251214_1300_e9b3d63f0cbf_remove_model_mappings_add_aliases.py new file mode 100644 index 0000000..f0bba1c --- /dev/null +++ b/alembic/versions/20251214_1300_e9b3d63f0cbf_remove_model_mappings_add_aliases.py @@ -0,0 +1,205 @@ +"""remove_model_mappings_add_aliases + +合并迁移: +1. 添加 provider_model_aliases 字段到 models 表 +2. 迁移 model_mappings 数据到 provider_model_aliases +3. 删除 model_mappings 表 + +Revision ID: e9b3d63f0cbf +Revises: 20251210_baseline +Create Date: 2025-12-14 13:00:22.828183+00:00 + +""" +import json +from datetime import datetime, timezone + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.orm import Session + + +# revision identifiers, used by Alembic. +revision = 'e9b3d63f0cbf' +down_revision = '20251210_baseline' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """添加 provider_model_aliases 字段,迁移数据,删除 model_mappings 表""" + # 1. 添加 provider_model_aliases 字段 + op.add_column( + 'models', + sa.Column('provider_model_aliases', sa.JSON(), nullable=True) + ) + + # 2. 迁移 model_mappings 数据 + bind = op.get_bind() + session = Session(bind=bind) + + model_mappings_table = sa.table( + "model_mappings", + sa.column("source_model", sa.String), + sa.column("target_global_model_id", sa.String), + sa.column("provider_id", sa.String), + sa.column("mapping_type", sa.String), + sa.column("is_active", sa.Boolean), + ) + + models_table = sa.table( + "models", + sa.column("id", sa.String), + sa.column("provider_id", sa.String), + sa.column("global_model_id", sa.String), + sa.column("provider_model_aliases", sa.JSON), + sa.column("updated_at", sa.DateTime(timezone=True)), + ) + + def normalize_alias_list(value) -> list[dict]: + """将 DB 返回的 JSON 值规范化为 list[{'name': str, 'priority': int}]""" + if value is None: + return [] + + if isinstance(value, str): + try: + value = json.loads(value) if value else [] + except Exception: + return [] + + if not isinstance(value, list): + return [] + + normalized: list[dict] = [] + for item in value: + if not isinstance(item, dict): + continue + + raw_name = item.get("name") + if not isinstance(raw_name, str): + continue + name = raw_name.strip() + if not name: + continue + + raw_priority = item.get("priority", 1) + try: + priority = int(raw_priority) + except Exception: + priority = 1 + if priority < 1: + priority = 1 + + normalized.append({"name": name, "priority": priority}) + + return normalized + + # 查询所有活跃的 provider 级别 alias(只迁移 is_active=True 且 mapping_type='alias' 的) + # 全局别名/映射不迁移(新架构不再支持 source_model -> GlobalModel.name 的解析) + mappings = session.execute( + sa.select( + model_mappings_table.c.source_model, + model_mappings_table.c.target_global_model_id, + model_mappings_table.c.provider_id, + ) + .where( + model_mappings_table.c.is_active.is_(True), + model_mappings_table.c.provider_id.isnot(None), + model_mappings_table.c.mapping_type == "alias", + ) + .order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model) + ).all() + + # 按 (provider_id, target_global_model_id) 分组,收集别名 + alias_groups: dict = {} + for source_model, target_global_model_id, provider_id in mappings: + if not isinstance(source_model, str): + continue + source_model = source_model.strip() + if not source_model: + continue + if not isinstance(provider_id, str) or not provider_id: + continue + if not isinstance(target_global_model_id, str) or not target_global_model_id: + continue + + key = (provider_id, target_global_model_id) + if key not in alias_groups: + alias_groups[key] = [] + priority = len(alias_groups[key]) + 1 + alias_groups[key].append({"name": source_model, "priority": priority}) + + # 更新对应的 models 记录 + for (provider_id, global_model_id), aliases in alias_groups.items(): + model_row = session.execute( + sa.select(models_table.c.id, models_table.c.provider_model_aliases) + .where( + models_table.c.provider_id == provider_id, + models_table.c.global_model_id == global_model_id, + ) + .limit(1) + ).first() + + if model_row: + model_id = model_row[0] + existing_aliases = normalize_alias_list(model_row[1]) + + existing_names = {a["name"] for a in existing_aliases} + merged_aliases = list(existing_aliases) + for alias in aliases: + name = alias.get("name") + if not isinstance(name, str): + continue + name = name.strip() + if not name or name in existing_names: + continue + + merged_aliases.append( + { + "name": name, + "priority": len(merged_aliases) + 1, + } + ) + existing_names.add(name) + + session.execute( + models_table.update() + .where(models_table.c.id == model_id) + .values( + provider_model_aliases=merged_aliases if merged_aliases else None, + updated_at=datetime.now(timezone.utc), + ) + ) + + session.commit() + + # 3. 删除 model_mappings 表 + op.drop_table('model_mappings') + + +def downgrade() -> None: + """恢复 model_mappings 表,移除 provider_model_aliases 字段""" + # 1. 恢复 model_mappings 表 + op.create_table( + 'model_mappings', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('source_model', sa.String(200), nullable=False), + sa.Column( + 'target_global_model_id', + sa.String(36), + sa.ForeignKey('global_models.id', ondelete='CASCADE'), + nullable=False, + ), + sa.Column('provider_id', sa.String(36), sa.ForeignKey('providers.id'), nullable=True), + sa.Column('mapping_type', sa.String(20), nullable=False, server_default='alias'), + sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False), + sa.UniqueConstraint('source_model', 'provider_id', name='uq_model_mapping_source_provider'), + ) + op.create_index('ix_model_mappings_source_model', 'model_mappings', ['source_model']) + op.create_index('ix_model_mappings_target_global_model_id', 'model_mappings', ['target_global_model_id']) + op.create_index('ix_model_mappings_provider_id', 'model_mappings', ['provider_id']) + op.create_index('ix_model_mappings_mapping_type', 'model_mappings', ['mapping_type']) + + # 2. 移除 provider_model_aliases 字段 + op.drop_column('models', 'provider_model_aliases')