# File-level setup: imports, reproducibility seeding, and device selection
# for the graph-balancing pipeline defined below.
import warnings
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch_geometric.data import Data
import random
# Silence library warnings and fix all RNG seeds for reproducible runs
warnings.filterwarnings("ignore")
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Device configuration: prefer the first CUDA GPU when available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")
# Project-local model classes and metric helpers
from models import Generator, GCN
from utils import euclidean_dist, accuracy
class GraphBalanceConfig:
    """Configuration class - central place for all hyperparameters."""
    def __init__(self):
        # Model parameters
        self.generator_input_dim = 300  # generator input noise dimension
        self.hidden_dim = 1280          # GCN hidden-layer dimension
        self.num_classes = 2            # number of classification classes
        self.dropout_rate = 0.1         # dropout rate
        # Training parameters
        self.learning_rate = 0.0001     # learning rate (shared by G and D)
        self.weight_decay = 0.0005      # Adam weight decay
        self.num_epochs = 300           # training epochs per graph
        self.num_iterations = 10        # generator iterations
        # Data parameters
        self.threshold = 0.038          # adjacency-matrix link threshold
        self.train_ratio = 0.6          # training-set ratio
        self.val_ratio = 0.2            # validation-set ratio
        self.test_ratio = 0.2           # test-set ratio
        self.min_positive_samples = 5   # minimum positive samples per usable graph
        # Path configuration
        self.input_path = './ImGAGN/573.pt'
        self.train_output_path = './ImGAGN/573_train.pt'
        self.val_output_path = './ImGAGN/573_val.pt'
class GraphDataAnalyzer:
    """Graph-data analyzer: loads a list of graphs, analyzes per-graph class
    balance, and splits the graphs into train and validation/test pools."""

    def __init__(self, config):
        self.config = config
        self.original_data = None  # full list of loaded graph objects
        self.train_data = None     # graphs selected for training
        self.val_data = None       # graphs held out for validation/testing

    def load_data(self, file_path):
        """Load the serialized list of graph objects from *file_path*.

        Returns True on success, False on any failure (the error is printed).
        """
        try:
            print(f"正在加载数据文件: {file_path}")
            self.original_data = torch.load(file_path, weights_only=False, map_location=device)
            print(f"成功加载 {len(self.original_data)} 个图数据对象")
            return True
        except Exception as e:
            print(f"加载数据失败: {e}")
            return False

    def analyze_graph_balance(self, graph_data):
        """Return class-balance statistics for a single binary-labeled graph."""
        labels = graph_data.y
        num_negatives = torch.sum(labels == 0).item()
        num_positives = torch.sum(labels == 1).item()
        imbalance_gap = abs(num_negatives - num_positives)
        return {
            'num_nodes': graph_data.num_nodes,
            'num_negatives': num_negatives,
            'num_positives': num_positives,
            'imbalance_gap': imbalance_gap,
            'positive_ratio': num_positives / graph_data.num_nodes if graph_data.num_nodes > 0 else 0
        }

    def split_data_by_balance(self):
        """Partition loaded graphs by whether they have enough positive samples."""
        if not self.original_data:
            raise ValueError("请先加载数据")
        well_balanced_graphs = []    # graphs with >= min_positive_samples positives
        sparse_positive_graphs = []  # graphs with too few positives to train on
        for graph in self.original_data:
            analysis = self.analyze_graph_balance(graph)
            if analysis['num_positives'] >= self.config.min_positive_samples:
                well_balanced_graphs.append(graph)
            else:
                sparse_positive_graphs.append(graph)
        print(f"数据分割结果:")
        print(f" - 正类丰富的图: {len(well_balanced_graphs)} 个")
        print(f" - 正类稀疏的图: {len(sparse_positive_graphs)} 个")
        return well_balanced_graphs, sparse_positive_graphs

    def create_train_test_split(self, well_balanced_graphs, sparse_positive_graphs):
        """Split graphs into train and validation/test sets.

        Bug fix: the 70/30 split is taken over *well_balanced_graphs* (the
        only pool eligible for training), not over the full original dataset.
        The previous ``int(len(self.original_data) * 0.7)`` could exceed the
        length of the well-balanced list whenever many graphs were
        sparse-positive, leaving validation with only sparse graphs.
        """
        # Shuffle so the split is random
        random.shuffle(well_balanced_graphs)
        # 70/30 split of the well-balanced pool
        train_size = int(len(well_balanced_graphs) * 0.7)
        self.train_data = well_balanced_graphs[:train_size]
        # Sparse-positive graphs are never trained on; they all go to validation
        self.val_data = well_balanced_graphs[train_size:] + sparse_positive_graphs
        print(f"训练集: {len(self.train_data)} 个图")
        print(f"验证/测试集: {len(self.val_data)} 个图")
        return self.train_data, self.val_data
class GANModelBuilder:
    """GAN model builder: constructs and configures the generator and the
    GCN discriminator together with their optimizers."""

    def __init__(self, config):
        self.config = config
        self.generator = None
        self.discriminator = None

    def _make_adam(self, model):
        # Shared optimizer factory: both networks use Adam with the same
        # learning rate and weight decay taken from the config.
        return optim.Adam(
            model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

    def build_generator(self, output_dim):
        """Create the generator model; returns (generator, optimizer)."""
        self.generator = Generator(
            input_dim=self.config.generator_input_dim,
            output_dim=output_dim,
        ).to(device)
        self.generator_optimizer = self._make_adam(self.generator)
        return self.generator, self.generator_optimizer

    def build_discriminator(self, input_dim):
        """Create the GCN discriminator; returns (discriminator, optimizer)."""
        self.discriminator = GCN(
            nfeat=input_dim,
            nhid=self.config.hidden_dim,
            nclass=self.config.num_classes,
            dropout=self.config.dropout_rate,
        ).to(device)
        self.discriminator_optimizer = self._make_adam(self.discriminator)
        return self.discriminator, self.discriminator_optimizer
class GraphAugmentor:
    """Graph augmentor: synthesizes minority-class nodes and wires them into
    the original graph to balance the two classes."""

    def __init__(self, config, generator, generator_optimizer):
        self.config = config
        self.generator = generator
        self.generator_optimizer = generator_optimizer

    def calculate_imbalance_gap(self, labels):
        """Return (|#neg - #pos|, #pos, #neg) for a binary label tensor."""
        negatives = torch.sum(labels == 0).item()
        positives = torch.sum(labels == 1).item()
        return abs(negatives - positives), positives, negatives

    def generate_synthetic_nodes(self, gap_size, minority_features):
        """Generate *gap_size* synthetic node feature vectors.

        The generator maps random noise to an adjacency-like matrix; its
        columns over the minority nodes are softmax-normalized into link
        weights, and each synthetic feature is the weighted combination of
        minority features. Returns (features, adjacency, link_relationship).
        """
        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)
        adjacency_matrix = self.generator(noise)
        # Rescale the generator output from [-1, 1] to [0, 1]
        adjacency_matrix = (adjacency_matrix+1)/2
        print(f"关联矩阵的形状:{adjacency_matrix.shape}")
        link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
        print(f"连接关系的形状:{link_relationship.shape}")
        synthetic_features = torch.mm(link_relationship, minority_features)
        return synthetic_features, adjacency_matrix, link_relationship

    def build_augmented_graph(self, original_graph, synthetic_features, link_relationship):
        """Concatenate synthetic nodes onto the graph and add their edges.

        An undirected edge (both directions) is created between a synthetic
        node and every minority node whose link weight exceeds the configured
        threshold. Returns (augmented_features, augmented_edge_index).
        """
        base_features = original_graph.x
        base_edges = original_graph.edge_index
        augmented_features = torch.cat((base_features, synthetic_features), dim=0)
        # Threshold the soft link weights into a hard 0/1 adjacency mask
        adjacency_mask = (link_relationship > self.config.threshold).int()
        # Synthetic node ids are appended after all original nodes
        offset = base_features.shape[0]
        new_edges = []
        for row, col in torch.nonzero(adjacency_mask).tolist():
            synthetic_id = row + offset
            # Add both directions so the link is effectively undirected
            new_edges.extend(([synthetic_id, col], [col, synthetic_id]))
        if not new_edges:
            return augmented_features, base_edges
        extra_edges = torch.tensor(new_edges, device=device).t()
        return augmented_features, torch.cat((base_edges, extra_edges), dim=1)

    def create_balanced_labels(self, original_labels, num_synthetic_nodes):
        """Append positive (1) labels for the synthetic nodes; returns a flat tensor.

        NOTE(review): the cat along dim=1 assumes *original_labels* has shape
        (1, num_nodes) — confirm against the stored graph label layout.
        """
        extra = torch.ones((1, num_synthetic_nodes), dtype=original_labels.dtype, device=device)
        return torch.cat((original_labels, extra), dim=1).flatten()
class GANTrainer:
    """GAN trainer - implements true adversarial training: each epoch runs
    one generator update followed by one discriminator update."""

    def __init__(self, config, discriminator, discriminator_optimizer,
                 generator, generator_optimizer, augmentor):
        self.config = config
        self.discriminator = discriminator
        self.discriminator_optimizer = discriminator_optimizer
        self.generator = generator
        self.generator_optimizer = generator_optimizer
        self.augmentor = augmentor
        self.training_history = []  # reserved; not populated in this file

    def create_data_splits(self, total_positive_samples, total_negative_samples, num_synthetic_nodes):
        """Create train/val/test index splits over positive, negative, and synthetic nodes.

        NOTE(review): positive and negative indices are each drawn from
        ranges starting at 0 and concatenated as-is, so the two index sets
        overlap unless negatives are meant to be offset by the positive
        count somewhere else — verify against the actual node ordering.
        """
        # Positive-sample split
        positive_indices = torch.arange(total_positive_samples)
        train_size_pos = int(len(positive_indices) * self.config.train_ratio)
        val_size_pos = int(len(positive_indices) * self.config.val_ratio)
        permuted_pos = positive_indices[torch.randperm(len(positive_indices))]
        idx_train_pos = permuted_pos[:train_size_pos]
        idx_val_pos = permuted_pos[train_size_pos:train_size_pos + val_size_pos]
        idx_test_pos = permuted_pos[train_size_pos + val_size_pos:]
        # Negative-sample split
        negative_indices = torch.arange(total_negative_samples)
        train_size_neg = int(len(negative_indices) * self.config.train_ratio)
        val_size_neg = int(len(negative_indices) * self.config.val_ratio)
        permuted_neg = negative_indices[torch.randperm(len(negative_indices))]
        idx_train_neg = permuted_neg[:train_size_neg]
        idx_val_neg = permuted_neg[train_size_neg:train_size_neg + val_size_neg]
        idx_test_neg = permuted_neg[train_size_neg + val_size_neg:]
        # Merge indices; synthetic nodes occupy ids after all real nodes
        num_real_samples = train_size_pos + train_size_neg
        synthetic_indices = torch.arange(total_positive_samples + total_negative_samples,
                                         total_positive_samples + total_negative_samples + num_synthetic_nodes)
        idx_train = torch.cat((idx_train_pos, idx_train_neg, synthetic_indices))
        idx_val = torch.cat((idx_val_pos, idx_val_neg))
        idx_test = torch.cat((idx_test_pos, idx_test_neg))
        print(f"idx_train, idx_val, idx_test, num_real_samples分别为:{idx_train, idx_val, idx_test, num_real_samples}")
        return idx_train, idx_val, idx_test, num_real_samples

    def train_generator_step(self, original_graph, minority_features, gap_size, minority_indices):
        """Generator step: learn to fool the discriminator.

        Generates synthetic nodes, splices them into the graph, and updates
        only the generator so the discriminator labels them as real (0).
        Returns the scalar generator loss.
        """
        self.generator.train()
        self.discriminator.eval()  # freeze BN/Dropout behavior in the discriminator
        # 1. Clear generator gradients (discriminator grads are computed by
        #    backward() but never applied, since only the generator steps here)
        self.generator_optimizer.zero_grad()
        # 2. Generate synthetic nodes
        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)
        adjacency_matrix = self.generator(noise)
        adjacency_matrix = (adjacency_matrix + 1) / 2
        link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
        synthetic_features = torch.mm(link_relationship, minority_features)
        # 3. Build the augmented graph
        augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
            original_graph, synthetic_features, link_relationship
        )
        # 4. Build the sparse adjacency matrix
        sparse_adj = torch.sparse_coo_tensor(
            augmented_edge_index,
            torch.ones(augmented_edge_index.shape[1]),
            size=(augmented_features.shape[0], augmented_features.shape[0]),
            dtype=torch.float32,
            device=device
        )
        # 5. Discriminator forward pass (NOT wrapped in no_grad, so gradients
        #    flow back into the generator)
        _, output_gen, _ = self.discriminator(augmented_features, sparse_adj)
        # 6. Generator loss: make the discriminator misclassify synthetic
        #    nodes as real (label 0)
        num_original = original_graph.x.shape[0]
        synthetic_start = num_original
        synthetic_end = num_original + gap_size
        generator_loss = F.nll_loss(
            output_gen[synthetic_start:synthetic_end],
            torch.zeros(gap_size, dtype=torch.long, device=device)
        )
        # 7. Backward pass (gradients reach both generator and discriminator params)
        generator_loss.backward()
        # 8. Update generator parameters only (discriminator left untouched)
        self.generator_optimizer.step()
        return generator_loss.item()

    def train_discriminator_step(self, original_graph, minority_indices, gap_size,
                                 idx_train, idx_val, idx_test):
        """Discriminator step: learn to separate real from synthetic nodes.

        Returns a dict with the discriminator loss plus train/val metrics.
        NOTE(review): ``num_real_samples`` below is the total number of real
        nodes (original_graph.x.shape[0]), while create_data_splits returns
        a different quantity (train-split size) under the same name —
        confirm the ``idx_train[:num_real_samples]`` slices are intended.
        """
        self.discriminator.train()
        self.generator.eval()  # freeze the generator
        # 1. Clear discriminator gradients
        self.discriminator_optimizer.zero_grad()
        # 2. Regenerate synthetic nodes with the (frozen) generator
        with torch.no_grad():
            minority_features = original_graph.x[minority_indices]
            noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)
            adjacency_matrix = self.generator(noise)
            adjacency_matrix = (adjacency_matrix + 1) / 2
            link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
            synthetic_features = torch.mm(link_relationship, minority_features)
            # Build the augmented graph and the balanced label vector
            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
                original_graph, synthetic_features, link_relationship
            )
            balanced_labels = self.augmentor.create_balanced_labels(original_graph.y, gap_size)
        # 3. Build the sparse adjacency matrix
        sparse_adj = torch.sparse_coo_tensor(
            augmented_edge_index,
            torch.ones(augmented_edge_index.shape[1]),
            size=(augmented_features.shape[0], augmented_features.shape[0]),
            dtype=torch.float32,
            device=device
        )
        # 4. Discriminator forward pass
        output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)
        # 5. Real/fake labels (real = 0, synthetic = 1)
        num_real_samples = original_graph.x.shape[0]
        real_labels = torch.cat([
            torch.zeros(num_real_samples, dtype=torch.long, device=device),
            torch.ones(gap_size, dtype=torch.long, device=device)
        ])
        # 6. Discriminator loss.
        # NOTE(review): this distance term compares the minority features to
        # themselves, which is constant w.r.t. the model parameters — it
        # likely should involve the synthetic features instead (original
        # comment also flagged this as needing adjustment).
        distance_loss = -euclidean_dist(
            original_graph.x[minority_indices],
            original_graph.x[minority_indices]
        ).mean()
        total_loss = (
            F.nll_loss(output_real[idx_train[:num_real_samples]],
                       balanced_labels[idx_train[:num_real_samples]]) +
            F.nll_loss(output_gen[idx_train], real_labels[idx_train]) +
            distance_loss
        )
        # 7. Backward pass and optimizer step
        total_loss.backward()
        self.discriminator_optimizer.step()
        # 8. Validation pass (eval mode, no gradients)
        self.discriminator.eval()
        with torch.no_grad():
            output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)
            recall_val, f1_val, auc_val, acc_val, pre_val = accuracy(
                output_real[idx_val], balanced_labels[idx_val], output_auc[idx_val]
            )
            recall_train, f1_train, auc_train, acc_train, pre_train = accuracy(
                output_real[idx_train[:num_real_samples]],
                balanced_labels[idx_train[:num_real_samples]],
                output_auc[idx_train[:num_real_samples]]
            )
        return {
            'discriminator_loss': total_loss.item(),
            'train_recall': recall_train, 'train_f1': f1_train, 'train_acc': acc_train,
            'val_recall': recall_val, 'val_f1': f1_val, 'val_acc': acc_val,
        }

    def train_epoch(self, original_graph, minority_indices, gap_size,
                    idx_train, idx_val, idx_test):
        """One full adversarial epoch: generator step, then discriminator step."""
        minority_features = original_graph.x[minority_indices]
        # 1. Train the generator (try to fool the discriminator)
        gen_loss = self.train_generator_step(original_graph, minority_features, gap_size, minority_indices)
        # 2. Train the discriminator (correctly separate real/synthetic)
        disc_metrics = self.train_discriminator_step(
            original_graph, minority_indices, gap_size, idx_train, idx_val, idx_test
        )
        # Merge the two results into a single metrics dict
        disc_metrics['generator_loss'] = gen_loss
        return disc_metrics
class GraphBalanceProcessor:
    """Main processor - orchestrates the full graph-balancing pipeline:
    load -> split -> per-graph adversarial balancing -> (optional) save."""

    def __init__(self, config=None):
        self.config = config or GraphBalanceConfig()
        self.data_analyzer = GraphDataAnalyzer(self.config)
        self.model_builder = GANModelBuilder(self.config)
        self.augmentor = None  # created fresh for each graph in process_single_graph
        self.trainer = None    # created fresh for each graph in process_single_graph
        self.processed_graphs = []

    def process_single_graph(self, graph_data, graph_index):
        """Balance one graph via adversarial training; returns the balanced graph
        (or the input graph unchanged when it is already balanced)."""
        print(f"\n处理图 {graph_index}:")
        # Analyze the class balance of this graph
        analysis = self.data_analyzer.analyze_graph_balance(graph_data)
        gap_size = analysis['imbalance_gap']
        if gap_size == 0:
            print(f" 跳过: 已经平衡")
            return graph_data
        print(f" 需要生成 {gap_size} 个正类样本")
        # Prepare the data
        features = graph_data.x
        labels = graph_data.y
        # NOTE(review): nonzero(...)[:, 0] takes the FIRST index dimension;
        # if labels has shape (1, N) (as create_balanced_labels assumes) this
        # would always yield zeros — confirm the label tensor layout.
        minority_indices = torch.nonzero(labels == 1)[:, 0]
        minority_features = features[minority_indices]
        # Build the models
        generator, generator_optimizer = self.model_builder.build_generator(
            output_dim=minority_features.shape[0]  # NOTE: original comment flags this may need to be the minority FEATURE dimension
        )
        discriminator, discriminator_optimizer = self.model_builder.build_discriminator(
            input_dim=features.shape[1]
        )
        # Create the augmentor
        self.augmentor = GraphAugmentor(self.config, generator, generator_optimizer)
        # Create the trainer with every component adversarial training needs
        self.trainer = GANTrainer(
            self.config,
            discriminator,
            discriminator_optimizer,
            generator,            # generator, so the trainer can update it
            generator_optimizer,  # its optimizer
            self.augmentor        # shared augmentor
        )
        # Create the index splits
        idx_train, idx_val, idx_test, num_real_samples = self.trainer.create_data_splits(
            analysis['num_positives'], analysis['num_negatives'], gap_size
        )
        # Adversarial training (both generator and discriminator are updated)
        print(f" 开始对抗训练...")
        best_metrics = None
        best_score = 0
        for epoch in range(self.config.num_epochs):
            metrics = self.trainer.train_epoch(
                graph_data, minority_indices, gap_size, idx_train, idx_val, idx_test
            )
            if (epoch+1) % 10 == 0:
                print(f" Epoch {epoch+1}/{self.config.num_epochs} | "
                      f"Gen Loss: {metrics['generator_loss']:.4f} | "
                      f"Disc Loss: {metrics['discriminator_loss']:.4f}")
            # Track the best epoch by the mean of validation recall and accuracy
            current_score = (metrics['val_recall'] + metrics['val_acc']) / 2
            if current_score > best_score:
                best_score = current_score
                best_metrics = metrics
        if best_metrics:
            print(f" 最佳验证性能: Recall={best_metrics['val_recall']:.4f}, "
                  f"F1={best_metrics['val_f1']:.4f}, Acc={best_metrics['val_acc']:.4f}")
        # Use the trained generator to produce the final balanced graph
        with torch.no_grad():
            synthetic_features, _, link_relationship = self.augmentor.generate_synthetic_nodes(
                gap_size, graph_data.x[minority_indices]
            )
            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
                graph_data, synthetic_features, link_relationship
            )
            balanced_labels = self.augmentor.create_balanced_labels(labels, gap_size)
        balanced_graph = Data(
            x=augmented_features,
            edge_index=augmented_edge_index,
            y=balanced_labels.view(1, -1)
        )
        print(f" 平衡后的图:{balanced_graph}")
        return balanced_graph

    def process_all_graphs(self):
        """Run the full pipeline over all graphs; returns True on success."""
        print("=" * 80)
        print("开始处理图数据平衡任务")
        print("=" * 80)
        # Load the data
        if not self.data_analyzer.load_data(self.config.input_path):
            return False
        # Split the data
        well_balanced, sparse_positive = self.data_analyzer.split_data_by_balance()
        train_data, val_data = self.data_analyzer.create_train_test_split(well_balanced, sparse_positive)
        # Save validation data.
        # NOTE(review): the save is commented out, but the message below still
        # claims the file was written — re-enable the save or fix the message.
        # torch.save(val_data, self.config.val_output_path)
        print(f"验证数据已保存到: {self.config.val_output_path}")
        # Process the training data
        processed_train_data = []
        for i, graph in enumerate(train_data):
            processed_graph = self.process_single_graph(graph, i)
            if processed_graph is not None:
                processed_train_data.append(processed_graph)
            # NOTE(review): this break stops after the FIRST graph — it looks
            # like a debugging leftover; remove it to process the full set.
            break
        # Save processed training data (also currently disabled — see note above)
        # torch.save(processed_train_data, self.config.train_output_path)
        print(f"\n训练数据已保存到: {self.config.train_output_path}")
        print(f"成功处理 {len(processed_train_data)} 个图")
        return True
if __name__ == "__main__":
    # Script entry point: build the configuration and the processor, run the
    # full pipeline, and report the outcome.
    config = GraphBalanceConfig()
    processor = GraphBalanceProcessor(config)
    success = processor.process_all_graphs()
    if not success:
        print("\n出了点问题呢!---处理失败,请检查错误信息,请大侠从头再来!")
    else:
        print("\n" + "=" * 80)
        print("【success】---图数据平衡处理完成!")
        print("=" * 80)
# NOTE: removed stray scraped page text ("暂无评论" / "No comments yet") — it was not valid Python and broke the file's syntax.