import warnings
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import k_hop_subgraph, is_undirected, to_undirected

import evaluate  # project-local evaluation module
# NOTE(review): this silences *all* warnings for the whole process.
warnings.filterwarnings('ignore')
# ==================== Hyperparameter configuration ====================
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# --- model sizes ---
LATENT_DIM = 64            # generator latent dimension (not referenced by the code visible in this file)
HIDDEN_DIM = 256           # hidden width of the graph neural networks
K_CANDIDATES = 5           # number of candidate neighbours per synthetic node (tunable)

# --- training / generation ---
NUM_EPOCHS_PER_GRAPH = 50  # GAN training epochs per graph
LR_GAN = 1e-3              # Adam learning rate shared by generator and discriminator
EDGE_THRESHOLD = 0.3       # minimum edge probability for a candidate edge to be kept

# --- reproducibility ---
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
# ==================== 1. Geo-ImGAGN 模型定义 ====================
class GeoGraphGenerator(nn.Module):
    """
    Generator: from a minority-class (positive) centre node, produce one synthetic node
    feature vector plus edge probabilities to `k_candidates` candidate neighbours.

    The centre's 1-hop subgraph is encoded with two GCN layers. The centre embedding and
    the mean-pooled subgraph embedding are concatenated, then decoded into
    (a) a feature vector of size `node_dim` and
    (b) `k_candidates` sigmoid probabilities, one per returned candidate neighbour.
    The intent is to learn the joint topology/feature pattern of the pocket micro-environment.
    """

    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM, k_candidates: int = K_CANDIDATES):
        super().__init__()
        # Graph encoder over the centre node and its 1-hop neighbours.
        self.conv1 = GCNConv(node_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Feature decoder: [centre || pooled] -> synthetic node features (same size as input features).
        self.feat_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_dim)
        )
        # Edge decoder: one probability in [0, 1] per candidate neighbour
        # (intended to mimic spatial proximity: higher = closer).
        self.edge_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, k_candidates),
            nn.Sigmoid()
        )
        self.k_candidates = k_candidates

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, center_idx: int) -> Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate one synthetic node around `center_idx`.

        :param x: node features [N, node_dim]
        :param edge_index: edge index [2, E]
        :param center_idx: original (global) index of the centre node
        :return: (synthetic features [1, node_dim], edge probabilities [1, k],
                  candidate-neighbour global indices [k])
        """
        # 1. 1-hop subgraph of the centre, relabelled to local indices.
        subset, sub_edge, center_map, _ = k_hop_subgraph(
            node_idx=center_idx,
            num_hops=1,
            edge_index=edge_index,
            num_nodes=x.size(0),
            relabel_nodes=True
        )
        # 2. Isolated centre: duplicate it so the GCN still sees a 2-node graph.
        if subset.size(0) < 2:
            subset = torch.tensor([center_idx, center_idx], device=x.device)
            sub_edge = torch.tensor([[0, 1], [1, 0]], device=x.device, dtype=torch.long)
            center_map = 0
        # 3. Encode the subgraph structure.
        sub_x = x[subset]
        h = F.relu(self.conv1(sub_x, sub_edge))
        h = F.relu(self.conv2(h, sub_edge))
        # 4. Centre embedding + mean-pooled embedding of the whole subgraph.
        h_center = h[center_map].squeeze(0)  # [hidden]
        h_pool = global_mean_pool(h, torch.zeros(h.size(0), dtype=torch.long, device=h.device)).squeeze(0)  # [hidden]
        # 5. Decode.
        h_cat = torch.cat([h_center.unsqueeze(0), h_pool.unsqueeze(0)], dim=1)  # [1, hidden*2]
        new_x = self.feat_decoder(h_cat)  # [1, node_dim]
        # Forced diversity: Gaussian noise is always added; the extra dropout only acts in train mode.
        new_x = new_x + torch.randn_like(new_x) * 0.1
        new_x = F.dropout(new_x, p=0.2, training=self.training)
        edge_probs = self.edge_decoder(h_cat)  # [1, k_candidates]
        # 6. Candidate neighbours: subgraph nodes other than the centre (global indices).
        candidate_neighbors = subset[subset != center_idx]
        if candidate_neighbors.numel() == 0:
            # Isolated centre: the centre itself is the only available anchor. Without this
            # fallback the repeat-fill below would divide by zero.
            candidate_neighbors = torch.tensor([center_idx], device=subset.device, dtype=subset.dtype)
        # Fewer than k candidates: repeat them cyclically to fill k slots.
        if candidate_neighbors.numel() < self.k_candidates:
            repeats = self.k_candidates // candidate_neighbors.numel() + 1
            candidate_neighbors = candidate_neighbors.repeat(repeats)
        return new_x, edge_probs, candidate_neighbors[:self.k_candidates]
class GeoDiscriminator(nn.Module):
    """
    Discriminator that scores a subgraph, with one flagged node, as real (1) or generated (0).

    Each node's features are augmented with a single flag column. The result is passed
    through one GCN layer, mean-pooled into a single graph vector, and mapped to a
    probability by an MLP.
    """

    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        flagged_dim = node_dim + 1  # node features + the 1-column "generated / centre" flag
        self.conv = GCNConv(flagged_dim, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, sub_x: torch.Tensor, sub_edge: torch.Tensor, is_gen_flag: torch.Tensor) -> torch.Tensor:
        """
        :param sub_x: subgraph node features [N_sub, node_dim]
        :param sub_edge: subgraph edge index [2, E_sub]
        :param is_gen_flag: flag column [N_sub, 1]; 1 = generated node / minority-class centre,
                            0 = other real node
        :return: probability that the subgraph is real, shape [1, 1]
        """
        flagged = torch.cat((sub_x, is_gen_flag), dim=1)
        node_emb = self.conv(flagged, sub_edge).relu()
        # All nodes belong to graph 0, so pooling yields one graph-level vector.
        single_graph = node_emb.new_zeros(node_emb.size(0), dtype=torch.long)
        pooled = global_mean_pool(node_emb, single_graph)
        return self.mlp(pooled)
# ==================== 2. 训练函数 ====================
def _build_fake_subgraph(x: torch.Tensor, fake_x: torch.Tensor, fake_probs: torch.Tensor,
                         neighbor_idx: torch.Tensor, device: str):
    """
    Assemble the star subgraph (selected real neighbours + the synthetic node) scored by D.

    :param x: full-graph node features (already on `device`)
    :param fake_x: synthetic node features [1, node_dim]
    :param fake_probs: edge probabilities [1, k]
    :param neighbor_idx: candidate-neighbour global indices [k]
    :param device: compute device
    :return: (sub_x, edge_index, is_gen_flag), or None when no candidate passes EDGE_THRESHOLD
    """
    # squeeze(0) rather than squeeze(): keeps the mask 1-D even when k == 1.
    mask = fake_probs.squeeze(0) > EDGE_THRESHOLD
    selected = neighbor_idx[mask.to(neighbor_idx.device)]
    n_sel = selected.numel()
    if n_sel == 0:
        return None
    sub_x = torch.cat([x[selected].to(device), fake_x.to(device)], dim=0)
    # The synthetic node is the last row (local index n_sel).
    nbr = torch.arange(n_sel, dtype=torch.long, device=device)
    gen = torch.full((n_sel,), n_sel, dtype=torch.long, device=device)
    # Both directions, so the fake subgraph is undirected like the real k-hop subgraphs.
    edge = torch.cat([torch.stack([nbr, gen], dim=0), torch.stack([gen, nbr], dim=0)], dim=1)
    is_gen = torch.zeros(sub_x.size(0), 1, device=device)
    is_gen[-1] = 1.0  # flag the synthetic node
    return sub_x, edge, is_gen


def train_generator_on_graph(graph: Data, generator: GeoGraphGenerator,
                             discriminator: GeoDiscriminator, device: str) -> GeoGraphGenerator:
    """
    Adversarially train `generator` and `discriminator` inside a single protein graph.

    Each epoch samples up to 8 positive (y == 1) nodes as centres. For each centre:
      - the discriminator learns to score the real 1-hop subgraph (centre flagged 1) as 1
        and a generated star subgraph (synthetic node flagged 1) as 0;
      - the generator learns to make the discriminator score its star subgraph as 1.
    Centres for which no candidate edge passes EDGE_THRESHOLD are skipped.

    NOTE(review): edge probabilities are only used through a hard threshold, so the
    generator's edge decoder receives no gradient from these losses.

    :param graph: single protein graph (x, edge_index, y)
    :param generator: generator instance (expected to live on `device`)
    :param discriminator: discriminator instance (expected to live on `device`)
    :param device: compute device
    :return: the trained generator (same object, left in train mode)
    """
    generator.train()
    discriminator.train()
    optG = torch.optim.Adam(generator.parameters(), lr=LR_GAN)
    optD = torch.optim.Adam(discriminator.parameters(), lr=LR_GAN)
    bce_loss = nn.BCELoss()

    # Move graph tensors once so the generator can consume them on its own device.
    x = graph.x.to(device)
    edge_index = graph.edge_index.to(device)
    num_nodes = x.size(0)

    pos_idx = (graph.y == 1).nonzero(as_tuple=True)[0]
    if pos_idx.numel() == 0:
        return generator  # no positive nodes: nothing to train on

    batch_size = min(8, len(pos_idx))
    for epoch in range(NUM_EPOCHS_PER_GRAPH):
        epoch_loss_d, epoch_loss_g = 0.0, 0.0
        steps_d, steps_g = 0, 0
        # Fresh random subset of centres every epoch (diversity).
        centers = pos_idx[torch.randperm(len(pos_idx))[:batch_size]]
        for center in centers.tolist():
            # ---------------- discriminator step ----------------
            optD.zero_grad()
            real_subset, real_edge, real_center_map, _ = k_hop_subgraph(
                node_idx=center,
                num_hops=1,
                edge_index=edge_index,
                relabel_nodes=True,
                num_nodes=num_nodes
            )
            real_sub_x = x[real_subset]
            real_is_gen = torch.zeros(real_sub_x.size(0), 1, device=device)
            real_is_gen[real_center_map] = 1.0  # flag the real centre node
            D_real = discriminator(real_sub_x, real_edge, real_is_gen)
            loss_real = bce_loss(D_real, torch.ones_like(D_real))

            with torch.no_grad():
                fake_x, fake_probs, neighbor_idx = generator(x, edge_index, center)
            fake = _build_fake_subgraph(x, fake_x, fake_probs, neighbor_idx, device)
            if fake is None:
                continue  # no edge selected: skip this centre entirely
            D_fake = discriminator(*fake)
            loss_fake = bce_loss(D_fake, torch.zeros_like(D_fake))

            loss_d = loss_real + loss_fake
            loss_d.backward()
            optD.step()
            epoch_loss_d += loss_d.item()
            steps_d += 1

            # ---------------- generator step ----------------
            optG.zero_grad()
            # Regenerate with gradients enabled.
            fake_x, fake_probs, neighbor_idx = generator(x, edge_index, center)
            fake = _build_fake_subgraph(x, fake_x, fake_probs, neighbor_idx, device)
            if fake is None:
                continue
            D_fake = discriminator(*fake)
            # Non-saturating generator loss: push D(fake) towards 1.
            loss_g = bce_loss(D_fake, torch.ones_like(D_fake))
            # Also fills discriminator grads; they are cleared by optD.zero_grad() next step.
            loss_g.backward()
            optG.step()
            epoch_loss_g += loss_g.item()
            steps_g += 1

        if (epoch + 1) % 10 == 0:
            # Average over the steps actually taken (skipped centres are excluded).
            print(f"  Epoch {epoch+1}/{NUM_EPOCHS_PER_GRAPH}, "
                  f"Loss D: {epoch_loss_d/max(steps_d, 1):.4f}, Loss G: {epoch_loss_g/max(steps_g, 1):.4f}")
    return generator
# ==================== 3. 平衡函数 ====================
def balance_single_graph(graph: Data, generator: Optional[GeoGraphGenerator] = None,
                         device: str = DEVICE, k_candidates: int = K_CANDIDATES) -> Data:
    """
    Balance one protein graph 1:1 by appending synthetic positive nodes.

    `gap = #negatives - #positives` synthetic nodes are generated. Each one:
      - starts from a randomly chosen positive template node;
      - is connected, in both directions, to the candidate neighbours whose edge
        probability exceeds EDGE_THRESHOLD (or to the first candidate if none does);
      - is labelled 1.

    :param graph: original graph (not modified; a copy is processed)
    :param generator: pre-trained generator; if None a new one is trained on this graph
    :param device: compute device
    :param k_candidates: candidate neighbours per synthetic node (used for a newly trained generator)
    :return: graph on CPU. When synthetic nodes are added, only `x`, `edge_index` and `y`
        are kept (other attributes are dropped). When no balancing is needed or possible,
        the (copied) input graph is returned.
    """
    # Work on a copy so the caller's graph is neither made undirected nor moved to `device`.
    graph = graph.clone()
    if not is_undirected(graph.edge_index):
        graph.edge_index = to_undirected(graph.edge_index)
        print("警告: 输入图不是无向图,已自动转换")
    graph = graph.to(device)

    pos_idx = (graph.y == 1).nonzero(as_tuple=True)[0]
    neg_idx = (graph.y == 0).nonzero(as_tuple=True)[0]
    gap = len(neg_idx) - len(pos_idx)
    # Already balanced, or negatives are the minority.
    if gap <= 0:
        return graph.cpu()
    # No positive template to sample from (torch.randint(0, ...) below would raise).
    if len(pos_idx) == 0:
        print(" 警告: 图中没有正类节点,跳过平衡")
        return graph.cpu()

    if generator is None:
        print(" 训练新的生成器...")
        generator = GeoGraphGenerator(node_dim=graph.x.size(1), k_candidates=k_candidates).to(device)
        discriminator = GeoDiscriminator(node_dim=graph.x.size(1)).to(device)
        generator = train_generator_on_graph(graph, generator, discriminator, device)
    # Sample in inference mode for both fresh and pre-trained generators
    # (disables dropout; the generator's Gaussian noise is still applied).
    generator.eval()

    num_existing = graph.x.size(0)
    new_x_list, new_edge_list = [], []
    print(f" 生成 {gap} 个节点...")
    with torch.no_grad():
        for i in range(gap):
            # Random positive template centre.
            center_idx = pos_idx[torch.randint(len(pos_idx), (1,))].item()
            fake_x, edge_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx)
            # Keep high-probability candidates; squeeze(0) keeps the mask 1-D even when k == 1.
            mask = edge_probs.squeeze(0) > EDGE_THRESHOLD
            selected = neighbor_idx[mask.to(neighbor_idx.device)]
            if selected.numel() == 0:
                # Nothing selected: connect to the first candidate so the node is not isolated.
                selected = neighbor_idx[:1]
            new_x_list.append(fake_x.cpu())

            new_ids = torch.full_like(selected, num_existing + i)
            # Both directions so the graph stays undirected.
            edges = torch.cat([torch.stack([selected, new_ids], dim=0),
                               torch.stack([new_ids, selected], dim=0)], dim=1)
            new_edge_list.append(edges.cpu())

    # gap > 0 here, so both lists are non-empty.
    all_new_x = torch.cat(new_x_list, dim=0).to(graph.x.dtype)
    new_x = torch.cat([graph.x.cpu(), all_new_x], dim=0)
    new_y = torch.cat([graph.y.cpu(), torch.ones(gap, dtype=graph.y.dtype)], dim=0)
    merged_edges = torch.cat([graph.edge_index.cpu(), torch.cat(new_edge_list, dim=1)], dim=1)
    # to_undirected also coalesces (sorts and removes duplicate edges); it does not remove self-loops.
    new_edge_index = to_undirected(merged_edges, num_nodes=new_x.size(0))
    return Data(x=new_x, edge_index=new_edge_index, y=new_y)
def balance_graph_list(graph_list: List[Data], device: str = DEVICE,
                       k_candidates: int = K_CANDIDATES,
                       max_graphs: Optional[int] = None) -> List[Data]:
    """
    Balance a list of protein graphs one by one with `balance_single_graph`.

    A fresh generator/discriminator pair is trained for every graph (no generator is shared).

    :param graph_list: original graphs
    :param device: compute device
    :param k_candidates: number of candidate neighbours per synthetic node
    :param max_graphs: optional cap on how many graphs (from the front of the list) to
        process, e.g. 1 for a quick smoke test. None (default) processes every graph.
    :return: balanced graphs, in input order
    """
    graphs = graph_list if max_graphs is None else graph_list[:max_graphs]
    total = len(graphs)
    print(f"开始处理 {total} 个蛋白质图...")
    balanced_list = []
    for idx, graph in enumerate(graphs, start=1):
        print(f"\n[{idx}/{total}] 蛋白图 (原始节点: {graph.x.size(0)}, 正类: {graph.y.sum().item()})")
        # Measure before balancing so the ratio reflects the untouched input.
        ratio_before = graph.y.float().mean().item()
        balanced_graph = balance_single_graph(graph, None, device, k_candidates)
        ratio_after = balanced_graph.y.float().mean().item()
        print(f" └─> 平衡后: 节点={balanced_graph.x.size(0)}, 正类={balanced_graph.y.sum().item()}, "
              f"正类占比: {ratio_before:.2f} -> {ratio_after:.2f}")
        balanced_list.append(balanced_graph)
    return balanced_list
# ==================== 5. 主执行流程 ====================
def main():
    """
    Script entry point.

    Loads the protein graph list, balances every graph, and saves the result. It then
    prints a summary of the metrics returned by `evaluate.evaluate_generation_quality`.
    """
    from pathlib import Path  # only needed to create the output directory

    # Data paths
    DATA_PATH = "ESM/graph_data_573.pt"
    OUTPUT_PATH = "GAN/graph_data_573_balanced.pt"
    print("="*60)
    print("Geo-ImGAGN 蛋白质图非平衡处理")
    print("="*60)
    print(f"超参配置: 候选邻居={K_CANDIDATES}, 训练轮数={NUM_EPOCHS_PER_GRAPH}, 设备={DEVICE}")

    # 1. Load data
    print("\n[1/4] 加载蛋白质图数据...")
    try:
        # weights_only=False unpickles arbitrary objects: only load trusted files here.
        graph_list = torch.load(DATA_PATH, weights_only=False)
    except Exception as e:  # top-level boundary: report and stop
        print(f" 错误: 无法加载数据 {e}")
        return
    print(f" 成功加载 {len(graph_list)} 个蛋白质图")

    # 2. Balance every graph
    print("\n[2/4] 开始平衡处理...")
    balanced_graphs = balance_graph_list(graph_list, device=DEVICE, k_candidates=K_CANDIDATES)

    # 3. Save (torch.save does not create missing parent directories)
    print("\n[3/4] 保存平衡后的数据...")
    Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)
    torch.save(balanced_graphs, OUTPUT_PATH)
    print(f" 已保存至 {OUTPUT_PATH}")

    # 4. Quality metrics
    print("\n[4/4] 计算验证指标...")
    metrics = evaluate.evaluate_generation_quality(graph_list, balanced_graphs, device=DEVICE)

    # ==================== Final summary ====================
    print("\n" + "="*60)
    print("处理完成!最终总结")
    print("="*60)
    total_new_nodes = metrics['summary']['total_generated_nodes']
    print(f" 新增合成节点: {total_new_nodes:,} 个")
    # Level 1: spatial plausibility
    coverage_2hop = metrics['pocket_coverage_graph']['2hop_coverage']
    print(f" 图跳数2跳覆盖率: {coverage_2hop:.2%} {'✅ 达标' if coverage_2hop > 0.6 else '❌ 未达标'}")
    # Level 2: topological consistency
    jaccard = metrics['jaccard_similarity']['mean_score']
    degree_ratio = metrics['degree_ratio']['ratio']
    print(f" Jaccard相似度: {jaccard:.4f} {'✅ 达标' if jaccard > 0.4 else '❌ 未达标'}")
    print(f" 度数比率: {degree_ratio:.2f} {'✅ 达标' if 0.5 <= degree_ratio <= 1.5 else '❌ 未达标'}")
    # Level 3: features and diversity (KL is optional in the metrics dict)
    kl_div = metrics.get('kl_divergence', {}).get('value', 999.0)
    cos_sim = metrics['cosine_similarity']['mean']
    print(f" KL散度/维度: {kl_div:.4f} {'✅ 良好' if kl_div < 0.2 else '⚠️ 一般'}")
    print(f" 余弦相似度: {cos_sim:.4f} {'✅ 良好' if cos_sim < 0.85 else '⚠️ 偏高'}")

    # Overall assessment and advice
    print("\n" + "-"*60)
    print("整体评估与建议:")
    coverage_ok = metrics['pocket_coverage_graph']['status'] == 'PASS'
    if coverage_ok:
        print(" ✅ 空间合理性: 生成节点大多落在口袋2跳内")
    else:
        # Report the current value instead of recommending the value already in use.
        print(f" ❌ 空间合理性: 生成节点偏离口袋,建议降低EDGE_THRESHOLD(当前 {EDGE_THRESHOLD})")
    topology_ok = (metrics['jaccard_similarity']['status'] == 'PASS' and
                   metrics['degree_ratio']['status'] == 'PASS')
    if topology_ok:
        print(" ✅ 拓扑一致性: 连接模式与真实正类匹配")
    else:
        print(" ⚠️ 拓扑一致性: 建议增加NUM_EPOCHS_PER_GRAPH至100或调整K_CANDIDATES")
    diversity_ok = (metrics.get('kl_divergence', {}).get('status', 'N/A') != 'WARN' and
                    metrics['cosine_similarity']['status'] != 'WARN')
    if diversity_ok:
        print(" ✅ 多样性与特征: 模式丰富,无崩溃风险")
    else:
        print(" ⚠️ 多样性与特征: 建议监控G_loss,防止模式崩溃")

    # Final recommendation
    print("\n" + "最终建议:")
    if coverage_ok and topology_ok:
        print(" 🎉 生成节点质量达标!可直接用于下游GCN训练")
        print(" 📌 建议在验证集监控AUPRC,若下降回调EDGE_THRESHOLD")
    else:
        print(" 🔧 核心指标未达标,建议:")
        print(f" 1. 降低EDGE_THRESHOLD(当前 {EDGE_THRESHOLD})")
        print(f" 2. 增加K_CANDIDATES(当前 {K_CANDIDATES})")
        print(" 3. 增加NUM_EPOCHS_PER_GRAPH至100")
    print("\n代码执行完毕!")
    print("="*60)


if __name__ == '__main__':
    main()
# 暂无评论 ("No comments yet" — stray web-page text, not part of the program)