get_balance_graph_data.py(待完成)
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import k_hop_subgraph, is_undirected, to_undirected
from typing import List, Tuple
import evaluate
import warnings
warnings.filterwarnings('ignore')

# ==================== 超参数配置区 ====================
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LATENT_DIM = 64  # 生成器内部隐空间维度
HIDDEN_DIM = 256  # 图神经网络隐藏层维度
K_CANDIDATES = 5  # 候选邻居数量(可调超参数)
NUM_EPOCHS_PER_GRAPH = 50  # 每个图的训练轮数
LR_GAN = 1e-3  # 生成器和判别器学习率
EDGE_THRESHOLD = 0.3  # 生成边概率阈值
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


# ==================== 1. Geo-ImGAGN 模型定义 ====================

class GeoGraphGenerator(nn.Module):
    """
    生成器:输入少数类中心节点,输出合成节点特征 + 到候选邻居的边权重
    核心:学习口袋微环境的拓扑-特征协同模式
    """
    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM, k_candidates: int = K_CANDIDATES):
        super().__init__()
        # 图编码器:聚合中心及其k阶邻居信息
        self.conv1 = GCNConv(node_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        
        # 解码器:合成节点特征(ESM2风格)
        self.feat_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim)
        )
        
        # 边权重预测:模拟空间距离(概率越高=空间越近)
        self.edge_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, k_candidates),
            nn.Sigmoid()  # 输出[0,1]概率
        )
        
        self.k_candidates = k_candidates

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, center_idx: int) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        前向传播
        :param x: 节点特征 [N, node_dim]
        :param edge_index: 边索引 [2, E]
        :param center_idx: 中心节点原始索引
        :return: (合成节点特征[1, node_dim], 边权重[1, k], 候选邻居原始索引[k])
        """
        # 1. 提取中心节点的1阶子图(物理空间约束)
        subset, sub_edge, center_map, _ = k_hop_subgraph(
            node_idx=center_idx,
            num_hops=1,
            edge_index=edge_index,
            num_nodes=x.size(0),
            relabel_nodes=True
        )
        
        # 2. 如果子图太小,用中心节点填充防止崩溃
        if subset.size(0) < 2:
            subset = torch.tensor([center_idx, center_idx], device=x.device)
            sub_edge = torch.tensor([[0, 1], [1, 0]], device=x.device, dtype=torch.long)
            center_map = 0
        
        # 3. 图卷积编码结构信息
        sub_x = x[subset]
        h = F.relu(self.conv1(sub_x, sub_edge))
        h = F.relu(self.conv2(h, sub_edge))
        
        # 4. 中心节点特征 + 子图全局 pooling
        h_center = h[center_map].squeeze(0)  # [hidden]
        h_pool = global_mean_pool(h, torch.zeros(h.size(0), dtype=torch.long, device=h.device)).squeeze(0)  # [hidden]
        
        # 5. 解码生成
        h_cat = torch.cat([h_center.unsqueeze(0), h_pool.unsqueeze(0)], dim=1)  # [1, hidden*2]
        new_x = self.feat_decoder(h_cat)  # [1, node_dim]

        # ===== 强制多样化:高斯噪声 + Dropout =====
        new_x = new_x + torch.randn_like(new_x, device=new_x.device) * 0.1  # 高斯噪声
        new_x = F.dropout(new_x, p=0.2, training=self.training)            # dropout

        edge_probs = self.edge_decoder(h_cat)  # [1, k_candidates]
        
        # 6. 提取候选邻居原始索引(排除中心节点)
        candidate_neighbors = subset[subset != center_idx]
        
        # 如果不足k个,重复填充
        if len(candidate_neighbors) < self.k_candidates:
            repeats = (self.k_candidates // len(candidate_neighbors)) + 1
            candidate_neighbors = candidate_neighbors.repeat(repeats)[:self.k_candidates]
        
        return new_x, edge_probs, candidate_neighbors[:self.k_candidates]


class GeoDiscriminator(nn.Module):
    """
    判别器:判断子图-节点对是否来自真实数据
    输入:子图特征 + 是否为生成节点的标记
    """
    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        # 输入=特征 + 生成标记(1维)
        self.conv = GCNConv(node_dim + 1, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, sub_x: torch.Tensor, sub_edge: torch.Tensor, is_gen_flag: torch.Tensor) -> torch.Tensor:
        """
        :param sub_x: 子图节点特征 [N_sub, node_dim]
        :param sub_edge: 子图边索引 [2, E_sub]
        :param is_gen_flag: 生成标记 [N_sub, 1], 1=生成节点/少数类中心, 0=真实普通节点
        :return: 子图为真的概率 [1]
        """
        h = torch.cat([sub_x, is_gen_flag], dim=1)
        h = F.relu(self.conv(h, sub_edge))
        graph_rep = global_mean_pool(h, torch.zeros(h.size(0), dtype=torch.long, device=h.device))
        return self.mlp(graph_rep)


# ==================== 2. 训练函数 ====================

def train_generator_on_graph(graph: Data, generator: GeoGraphGenerator, 
                             discriminator: GeoDiscriminator, device: str) -> GeoGraphGenerator:
    """
    在单个蛋白质图内部训练生成器和判别器
    :param graph: 单个蛋白质图
    :param generator: 生成器实例
    :param discriminator: 判别器实例
    :param device: 计算设备
    :return: 训练后的生成器
    """
    generator.train()
    discriminator.train()
    
    optG = torch.optim.Adam(generator.parameters(), lr=LR_GAN)
    optD = torch.optim.Adam(discriminator.parameters(), lr=LR_GAN)
    bce_loss = nn.BCELoss()
    
    # 获取正类节点索引
    pos_idx = (graph.y == 1).nonzero(as_tuple=True)[0]
    if len(pos_idx) == 0:
        return generator  # 没有正类,无法训练
    
    # 每个epoch随机采样多个中心节点(增加多样性)
    for epoch in range(NUM_EPOCHS_PER_GRAPH):
        epoch_loss_d, epoch_loss_g = 0.0, 0.0
        
        # 每个epoch训练batch_size个中心节点
        batch_size = min(8, len(pos_idx))
        centers = pos_idx[torch.randperm(len(pos_idx))[:batch_size]]
        
        for center_idx in centers:
            # ---------------- 判别器训练 ----------------
            optD.zero_grad()
            
            # 真实子图
            # real_subset, real_edge, real_center_map, _ = k_hop_subgraph(
            #     center_idx.item(), 1, graph.edge_index, graph.x.size(0), relabel_nodes=True
            # )
            # 真实子图
            real_subset, real_edge, real_center_map, _ = k_hop_subgraph(
            node_idx=center_idx.item(),
            num_hops=1,
            edge_index=graph.edge_index,
            relabel_nodes=True,
            num_nodes=graph.x.size(0)
            )


            real_sub_x = graph.x[real_subset].to(device)
            real_is_gen = torch.zeros(real_sub_x.size(0), 1, device=device)
            real_is_gen[real_center_map] = 1.0  # 中心节点标记为1
            
            # 判别器打分
            D_real = discriminator(real_sub_x, real_edge.to(device), real_is_gen)
            loss_real = bce_loss(D_real, torch.ones_like(D_real))
            
            # 假子图(生成器生成)
            with torch.no_grad():
                fake_x, fake_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx.item())
                fake_x = fake_x.to(device)
            
            # 筛选高概率邻居
            mask = fake_probs.squeeze() > EDGE_THRESHOLD
            selected_neighbors = neighbor_idx[mask.cpu()]
            
            if len(selected_neighbors) == 0:  # 没有选中任何邻居,跳过
                continue
            
            # 构建假子图
            fake_neighbor_x = graph.x[selected_neighbors].to(device)
            fake_sub_x = torch.cat([fake_neighbor_x, fake_x], dim=0)
            fake_edge = torch.stack([
                torch.arange(len(selected_neighbors), dtype=torch.long, device=device),
                torch.full((len(selected_neighbors),), len(selected_neighbors), dtype=torch.long, device=device)
            ], dim=0)
            fake_is_gen = torch.zeros(fake_sub_x.size(0), 1, device=device)
            fake_is_gen[-1] = 1.0  # 生成节点标记为1
            
            # 判别器对假子图打分
            D_fake = discriminator(fake_sub_x, fake_edge, fake_is_gen)
            loss_fake = bce_loss(D_fake, torch.zeros_like(D_fake))
            
            # 判别器总损失
            loss_d = loss_real + loss_fake
            loss_d.backward()
            optD.step()
            epoch_loss_d += loss_d.item()
            
            # ---------------- 生成器训练 ----------------
            optG.zero_grad()
            
            # 重新生成(此时需要梯度)
            fake_x, fake_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx.item())
            fake_x = fake_x.to(device)
            
            # 重新筛选邻居
            mask = fake_probs.squeeze() > EDGE_THRESHOLD
            selected_neighbors = neighbor_idx[mask.cpu()]
            
            if len(selected_neighbors) == 0:
                continue
            
            # 重新构建假子图
            fake_neighbor_x = graph.x[selected_neighbors].to(device)
            fake_sub_x = torch.cat([fake_neighbor_x, fake_x], dim=0)
            fake_edge = torch.stack([
                torch.arange(len(selected_neighbors), dtype=torch.long, device=device),
                torch.full((len(selected_neighbors),), len(selected_neighbors), dtype=torch.long, device=device)
            ], dim=0)
            fake_is_gen = torch.zeros(fake_sub_x.size(0), 1, device=device)
            fake_is_gen[-1] = 1.0
            
            # 生成器损失:希望判别器对假子图打高分
            D_fake = discriminator(fake_sub_x, fake_edge, fake_is_gen)
            loss_g = bce_loss(D_fake, torch.ones_like(D_fake))
            
            loss_g.backward()
            optG.step()
            epoch_loss_g += loss_g.item()
        
        # 每10个epoch打印一次
        if (epoch + 1) % 10 == 0:
            print(f"    Epoch {epoch+1}/{NUM_EPOCHS_PER_GRAPH}, "
                  f"Loss D: {epoch_loss_d/batch_size:.4f}, Loss G: {epoch_loss_g/batch_size:.4f}")
    
    return generator

# ==================== 3. 平衡函数 ====================

def balance_single_graph(graph: Data, generator: GeoGraphGenerator = None, 
                         device: str = DEVICE, k_candidates: int = K_CANDIDATES) -> Data:
    """
    对单个蛋白质图进行1:1平衡
    :param graph: 原始图
    :param generator: 预训练生成器(如果为None则训练新的)
    :param device: 计算设备
    :param k_candidates: 候选邻居数量
    :return: 平衡后的图
    """
    # 确保edge_index是无向的
    if not is_undirected(graph.edge_index):
        graph.edge_index = to_undirected(graph.edge_index)
        print(f"警告: 输入图不是无向图,已自动转换")
    
    graph = graph.to(device)
    pos_idx = (graph.y == 1).nonzero(as_tuple=True)[0]
    neg_idx = (graph.y == 0).nonzero(as_tuple=True)[0]
    gap = len(neg_idx) - len(pos_idx)
    
    # 已平衡或负类更少
    if gap <= 0:
        return graph.cpu()
    
    # 如果没有提供生成器,则训练一个新的
    if generator is None:
        print(f"  训练新的生成器...")
        generator = GeoGraphGenerator(node_dim=graph.x.size(1), k_candidates=k_candidates).to(device)
        discriminator = GeoDiscriminator(node_dim=graph.x.size(1)).to(device)
        generator = train_generator_on_graph(graph, generator, discriminator, device)
    else:
        generator.eval()
    
    # 生成节点
    new_x_list, new_edge_list = [], []
    current_num_nodes = graph.x.size(0)
    
    print(f"  生成 {gap} 个节点...")
    with torch.no_grad():
        for i in range(gap):
            # 随机选择模板中心
            center_idx = pos_idx[torch.randint(len(pos_idx), (1,))].item()
            
            # 生成节点特征和边权重
            fake_x, edge_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx)
            
            # 筛选高概率邻居(模拟空间距离筛选)
            mask = edge_probs.squeeze().cpu() > EDGE_THRESHOLD
            selected_neighbors = neighbor_idx[mask.cpu()]
            
            if len(selected_neighbors) == 0:
                # 如果都没选中,默认连第一个候选邻居
                selected_neighbors = neighbor_idx[:1]
            
            # 保存生成节点特征
            new_x_list.append(fake_x.cpu())
            
            # 构建新边(无向图需要双向连接)
            new_node_id = current_num_nodes + i
            # 添加正向边:候选邻居 → 新节点
            new_edges_forward = torch.stack([
                selected_neighbors,
                torch.full_like(selected_neighbors, new_node_id, dtype=torch.long)
            ], dim=0)
            # 添加反向边:新节点 → 候选邻居
            new_edges_backward = torch.stack([
                torch.full_like(selected_neighbors, new_node_id, dtype=torch.long),
                selected_neighbors
            ], dim=0)
            # 合并
            new_edges = torch.cat([new_edges_forward, new_edges_backward], dim=1)
            new_edge_list.append(new_edges)
    
    # 合并所有新节点和边
    all_new_x = torch.cat(new_x_list, dim=0)
    new_x = torch.cat([graph.x.cpu(), all_new_x], dim=0)
    new_y = torch.cat([graph.y.cpu(), torch.ones(gap, dtype=torch.long)], dim=0)
    
    # 合并边
    if new_edge_list:
        all_new_edges = torch.cat(new_edge_list, dim=1)
        new_edge_index = torch.cat([graph.edge_index.cpu(), all_new_edges.cpu()], dim=1)
        # 去重和去自环
        new_edge_index = to_undirected(new_edge_index)
    else:
        new_edge_index = graph.edge_index.cpu()
    
    return Data(x=new_x, edge_index=new_edge_index, y=new_y)


def balance_graph_list(graph_list: List[Data], device: str = DEVICE, 
                       k_candidates: int = K_CANDIDATES) -> List[Data]:
    """
    批量处理多个蛋白质图
    :param graph_list: 原始图列表
    :param device: 计算设备
    :param k_candidates: 候选邻居数量
    :return: 平衡后的图列表
    """
    balanced_list = []
    print(f"开始处理 {len(graph_list)} 个蛋白质图...")
    
    for idx, graph in enumerate(graph_list):
        print(f"\n[{idx+1}/{len(graph_list)}] 蛋白图 (原始节点: {graph.x.size(0)}, 正类: {graph.y.sum().item()})")
        balanced_graph = balance_single_graph(graph, None, device, k_candidates)
        
        # 打印平衡信息
        ratio_before = graph.y.float().mean().item()
        ratio_after = balanced_graph.y.float().mean().item()
        print(f"  └─> 平衡后: 节点={balanced_graph.x.size(0)}, 正类={balanced_graph.y.sum().item()}, "
              f"正类占比: {ratio_before:.2f} -> {ratio_after:.2f}")
        
        balanced_list.append(balanced_graph)
        # 假设列表中只有一个图,平衡一个图看看情况。
        break
    
    return balanced_list

# ==================== 5. 主执行流程 ====================

def main():
    # 数据路径
    DATA_PATH = "ESM/graph_data_573.pt"
    OUTPUT_PATH = "GAN/graph_data_573_balanced.pt"
    
    print("="*60)
    print("Geo-ImGAGN 蛋白质图非平衡处理")
    print("="*60)
    print(f"超参配置: 候选邻居={K_CANDIDATES}, 训练轮数={NUM_EPOCHS_PER_GRAPH}, 设备={DEVICE}")
    
    # 1. 加载数据
    print("\n[1/4] 加载蛋白质图数据...")
    try:
        graph_list = torch.load(DATA_PATH, weights_only=False)
        print(f"  成功加载 {len(graph_list)} 个蛋白质图")
    except Exception as e:
        print(f"  错误: 无法加载数据 {e}")
        return
    
    # 2. 批量平衡处理
    print("\n[2/4] 开始平衡处理...")
    balanced_graphs = balance_graph_list(graph_list, device=DEVICE, k_candidates=K_CANDIDATES)
    
    # 3. 保存结果
    print("\n[3/4] 保存平衡后的数据...")
    torch.save(balanced_graphs, OUTPUT_PATH)
    print(f"  已保存至 {OUTPUT_PATH}")
    
    # 4. 验证指标
    print("\n[4/4] 计算验证指标...")
    metrics = evaluate.evaluate_generation_quality(graph_list, balanced_graphs, device=DEVICE)
    
   # ==================== 最终总结 ====================
    print("\n" + "="*60)
    print("处理完成!最终总结")
    print("="*60)
    
    # 基本信息
    total_new_nodes = metrics['summary']['total_generated_nodes']
    print(f"  新增合成节点: {total_new_nodes:,} 个")
    
    # 一级指标:空间合理性
    coverage_2hop = metrics['pocket_coverage_graph']['2hop_coverage']
    print(f"  图跳数2跳覆盖率: {coverage_2hop:.2%}  {'✅ 达标' if coverage_2hop > 0.6 else '❌ 未达标'}")
    
    # 二级指标:拓扑一致性
    jaccard = metrics['jaccard_similarity']['mean_score']
    degree_ratio = metrics['degree_ratio']['ratio']
    print(f"  Jaccard相似度: {jaccard:.4f}  {'✅ 达标' if jaccard > 0.4 else '❌ 未达标'}")
    print(f"  度数比率: {degree_ratio:.2f}  {'✅ 达标' if 0.5 <= degree_ratio <= 1.5 else '❌ 未达标'}")
    
    # 三级指标:特征与多样性
    kl_div = metrics.get('kl_divergence', {}).get('value', 999.0)
    cos_sim = metrics['cosine_similarity']['mean']
    print(f"  KL散度/维度: {kl_div:.4f}  {'✅ 良好' if kl_div < 0.2 else '⚠️ 一般'}")
    print(f"  余弦相似度: {cos_sim:.4f}  {'✅ 良好' if cos_sim < 0.85 else '⚠️ 偏高'}")
    
    # 整体评估与建议
    print("\n" + "-"*60)
    print("整体评估与建议:")
    
    # 一级指标是否通过
    if metrics['pocket_coverage_graph']['status'] == 'PASS':
        print("  ✅ 空间合理性: 生成节点大多落在口袋2跳内")
    else:
        print("  ❌ 空间合理性: 生成节点偏离口袋,建议降低EDGE_THRESHOLD至0.3")
    
    # 二级指标是否通过
    topology_ok = (metrics['jaccard_similarity']['status'] == 'PASS' and 
                   metrics['degree_ratio']['status'] == 'PASS')
    if topology_ok:
        print("  ✅ 拓扑一致性: 连接模式与真实正类匹配")
    else:
        print("  ⚠️ 拓扑一致性: 建议增加NUM_EPOCHS_PER_GRAPH至100或调整K_CANDIDATES")
    
    # 三级指标
    diversity_ok = (metrics.get('kl_divergence', {}).get('status', 'N/A') != 'WARN' and 
                    metrics['cosine_similarity']['status'] != 'WARN')
    if diversity_ok:
        print("  ✅ 多样性与特征: 模式丰富,无崩溃风险")
    else:
        print("  ⚠️ 多样性与特征: 建议监控G_loss,防止模式崩溃")
    
    # 最终建议
    print("\n" + "最终建议:")
    if metrics['pocket_coverage_graph']['status'] == 'PASS' and topology_ok:
        print("  🎉 生成节点质量达标!可直接用于下游GCN训练")
        print("  📌 建议在验证集监控AUPRC,若下降回调EDGE_THRESHOLD")
    else:
        print("  🔧 核心指标未达标,建议:")
        print("     1. 降低EDGE_THRESHOLD至0.3")
        print("     2. 增加K_CANDIDATES至5")
        print("     3. 增加NUM_EPOCHS_PER_GRAPH至100")
    
    print("\n代码执行完毕!")
    print("="*60)


if __name__ == '__main__':
    main()
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇