# get_balanced_data.py
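"""
Balance imbalanced graph datasets with an ImGAGN-style GAN.

The script loads a list of PyTorch Geometric graphs, splits them by how many
positive (minority-class) nodes they contain, and, for each imbalanced training
graph, trains a Generator against a GCN discriminator to synthesize minority
nodes until both classes have the same size. The balanced graphs are then
written to the configured output paths.
"""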
import warnings
import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch_geometric.data import Data
import random

# Suppress warnings and fix random seeds for reproducibility
warnings.filterwarnings("ignore")
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import model classes and utility functions
from models import Generator, GCN
from utils import euclidean_dist, accuracy


class GraphBalanceConfig:
    """配置类 - 集中管理所有超参数"""
    
    def __init__(self):
        # Model parameters
        self.generator_input_dim = 300  # Generator input noise dimension
        self.hidden_dim = 1280  # GCN hidden layer dimension
        self.num_classes = 2  # Number of classes
        self.dropout_rate = 0.1  # Dropout rate

        # Training parameters
        self.learning_rate = 0.0001  # Learning rate
        self.weight_decay = 0.0005  # Weight decay
        self.num_epochs = 300  # Training epochs per graph
        self.num_iterations = 10  # Generator iterations (not used in this script)

        # Data parameters
        self.threshold = 0.038  # Link threshold for the generated adjacency
        self.train_ratio = 0.6  # Training split ratio
        self.val_ratio = 0.2  # Validation split ratio
        self.test_ratio = 0.2  # Test split ratio
        self.min_positive_samples = 5  # Minimum number of positive samples

        # Path configuration
        self.input_path = './ImGAGN/573.pt'
        self.train_output_path = './ImGAGN/573_train.pt'
        self.val_output_path = './ImGAGN/573_val.pt'


class GraphDataAnalyzer:
    """图数据分析器 - 负责数据加载、分析和预处理"""
    
    def __init__(self, config):
        self.config = config
        self.original_data = None
        self.train_data = None
        self.val_data = None
        
    def load_data(self, file_path):
        """加载图数据文件"""
        try:
            print(f"正在加载数据文件: {file_path}")
            self.original_data = torch.load(file_path, weights_only=False, map_location=device)
            print(f"成功加载 {len(self.original_data)} 个图数据对象")
            return True
        except Exception as e:
            print(f"加载数据失败: {e}")
            return False
    
    def analyze_graph_balance(self, graph_data):
        """分析单个图的类别平衡情况"""
        labels = graph_data.y
        num_negatives = torch.sum(labels == 0).item()
        num_positives = torch.sum(labels == 1).item()
        imbalance_gap = abs(num_negatives - num_positives)
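        # Example: 8 negatives and 2 positives give imbalance_gap = 6 and positive_ratio = 0.2.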
        
        return {
            'num_nodes': graph_data.num_nodes,
            'num_negatives': num_negatives,
            'num_positives': num_positives,
            'imbalance_gap': imbalance_gap,
            'positive_ratio': num_positives / graph_data.num_nodes if graph_data.num_nodes > 0 else 0
        }
    
    def split_data_by_balance(self):
        """根据正类样本数量分割数据"""
        if not self.original_data:
            raise ValueError("请先加载数据")
        
        well_balanced_graphs = []  # graphs with >= min_positive_samples positive nodes
        sparse_positive_graphs = []  # graphs with fewer positive nodes
        
        for graph in self.original_data:
            analysis = self.analyze_graph_balance(graph)
            if analysis['num_positives'] >= self.config.min_positive_samples:
                well_balanced_graphs.append(graph)
            else:
                sparse_positive_graphs.append(graph)
        
        print(f"数据分割结果:")
        print(f"  - 正类丰富的图: {len(well_balanced_graphs)} 个")
        print(f"  - 正类稀疏的图: {len(sparse_positive_graphs)} 个")
        
        return well_balanced_graphs, sparse_positive_graphs
    
    def create_train_test_split(self, well_balanced_graphs, sparse_positive_graphs):
        """创建训练测试分割"""
        # 打乱顺序
        random.shuffle(well_balanced_graphs)
        
        # 按照7:3比例分割
        train_size = int(len(self.original_data) * 0.7)
        self.train_data = well_balanced_graphs[:train_size]
        self.val_data = well_balanced_graphs[train_size:] + sparse_positive_graphs
        
        print(f"训练集: {len(self.train_data)} 个图")
        print(f"验证/测试集: {len(self.val_data)} 个图")
        
        return self.train_data, self.val_data


class GANModelBuilder:
    """GAN模型构建器 - 负责创建和配置生成器、判别器"""
    
    def __init__(self, config):
        self.config = config
        self.generator = None
        self.discriminator = None
        
    def build_generator(self, output_dim):
        """构建生成器模型"""
        self.generator = Generator(
            input_dim=self.config.generator_input_dim,
            output_dim=output_dim
        ).to(device)
        
        self.generator_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        
        return self.generator, self.generator_optimizer
    
    def build_discriminator(self, input_dim):
        """构建判别器模型"""
        self.discriminator = GCN(
            nfeat=input_dim,
            nhid=self.config.hidden_dim,
            nclass=self.config.num_classes,
            dropout=self.config.dropout_rate
        ).to(device)
        
        self.discriminator_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        
        return self.discriminator, self.discriminator_optimizer


class GraphAugmentor:
    """图数据增强器 - 负责生成合成节点和平衡图数据"""
    
    def __init__(self, config, generator, generator_optimizer):
        self.config = config
        self.generator = generator
        self.generator_optimizer = generator_optimizer
        
    def calculate_imbalance_gap(self, labels):
        """计算类别不平衡差距"""
        num_negatives = torch.sum(labels == 0).item()
        num_positives = torch.sum(labels == 1).item()
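        # Example: labels = [0, 0, 0, 0, 1] -> returns (3, 1, 4): a gap of 3, 1 positive, 4 negatives.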
        return abs(num_negatives - num_positives), num_positives, num_negatives
    
    def generate_synthetic_nodes(self, gap_size, minority_features):
        """生成合成节点特征"""
        
        # 生成随机噪声
        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)

        # 使用生成器生成邻接关系
        adjacency_matrix = self.generator(noise)
        adjacency_matrix = (adjacency_matrix+1)/2
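        # (adjacency_matrix + 1) / 2 maps the generator output, assumed to lie in [-1, 1] (e.g. a tanh head), into [0, 1].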
        print(f"关联矩阵的形状:{adjacency_matrix.shape}")
        
        # 计算与少数类样本的连接关系
        link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
        print(f"连接关系的形状:{link_relationship.shape}")

        # 生成节点特征
        synthetic_features = torch.mm(link_relationship, minority_features)
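        # Each row of link_relationship sums to 1 (softmax), so every synthetic node's feature vector
        # is a weighted average of the minority-node features.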
        
        return synthetic_features, adjacency_matrix, link_relationship
    
    def build_augmented_graph(self, original_graph, synthetic_features, link_relationship):
        """构建增强后的图"""
        original_features = original_graph.x
        original_edge_index = original_graph.edge_index
        
        # Concatenate the features
        augmented_features = torch.cat((original_features, synthetic_features), dim=0)

        # Build edge connections for the newly generated nodes
        threshold = self.config.threshold
        adjacency_mask = (link_relationship > threshold).int()

        # Node index bookkeeping
        num_original_nodes = original_features.shape[0]
        num_synthetic_nodes = synthetic_features.shape[0]

        # Build the new edge index
        synthetic_edges = []
        nonzero_positions = torch.nonzero(adjacency_mask)
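        # NOTE: pos[1] below is a column index into link_relationship, i.e. a position within the
        # minority subset, and is used directly as a node id in the original graph. This is only
        # correct if the minority nodes occupy the first indices of the graph; mapping the column
        # index through the minority index tensor would be more robust.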
        
        for pos in nonzero_positions:
            synthetic_node_id = pos[0].item() + num_original_nodes  # e.g. 0 (index among synthetic nodes) + 3 (original node count) = 3 (global index of the synthetic node in the augmented graph)
            original_node_id = pos[1].item()  # e.g. original_node_id = 1
            synthetic_edges.append([synthetic_node_id, original_node_id])
            synthetic_edges.append([original_node_id, synthetic_node_id])  # bidirectional link: these two appends add the edges [3, 1] and [1, 3]
        
        if synthetic_edges:
            synthetic_edge_index = torch.tensor(synthetic_edges, device=device).t()
            augmented_edge_index = torch.cat((original_edge_index, synthetic_edge_index), dim=1)
        else:
            augmented_edge_index = original_edge_index
        
        return augmented_features, augmented_edge_index
    
    def create_balanced_labels(self, original_labels, num_synthetic_nodes):
        """创建平衡后的标签"""
        synthetic_labels = torch.ones((1, num_synthetic_nodes), dtype=original_labels.dtype, device=device)
        balanced_labels = torch.cat((original_labels, synthetic_labels), dim=1)
        return balanced_labels.flatten()


class GANTrainer:
    """GAN训练器 - 现在实现真正的对抗训练"""
    
    def __init__(self, config, discriminator, discriminator_optimizer, 
                 generator, generator_optimizer, augmentor):
        self.config = config
        self.discriminator = discriminator
        self.discriminator_optimizer = discriminator_optimizer
        self.generator = generator
        self.generator_optimizer = generator_optimizer
        self.augmentor = augmentor
        self.training_history = []
 
    def create_data_splits(self, total_positive_samples, total_negative_samples, num_synthetic_nodes):
        """创建训练、验证、测试数据分割"""
        # 正样本分割
        positive_indices = torch.arange(total_positive_samples)
        train_size_pos = int(len(positive_indices) * self.config.train_ratio)
        val_size_pos = int(len(positive_indices) * self.config.val_ratio)
        
        permuted_pos = positive_indices[torch.randperm(len(positive_indices))]
        idx_train_pos = permuted_pos[:train_size_pos]
        idx_val_pos = permuted_pos[train_size_pos:train_size_pos + val_size_pos]
        idx_test_pos = permuted_pos[train_size_pos + val_size_pos:]
        
        # Split the negative samples
        negative_indices = torch.arange(total_negative_samples)
        train_size_neg = int(len(negative_indices) * self.config.train_ratio)
        val_size_neg = int(len(negative_indices) * self.config.val_ratio)
        
        permuted_neg = negative_indices[torch.randperm(len(negative_indices))]
        idx_train_neg = permuted_neg[:train_size_neg]
        idx_val_neg = permuted_neg[train_size_neg:train_size_neg + val_size_neg]
        idx_test_neg = permuted_neg[train_size_neg + val_size_neg:]
        
        # Combine the indices
        num_real_samples = train_size_pos + train_size_neg
        synthetic_indices = torch.arange(total_positive_samples + total_negative_samples, 
                                       total_positive_samples + total_negative_samples + num_synthetic_nodes)
        
        idx_train = torch.cat((idx_train_pos, idx_train_neg, synthetic_indices))
        idx_val = torch.cat((idx_val_pos, idx_val_neg))
        idx_test = torch.cat((idx_test_pos, idx_test_neg))
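        # Synthetic node indices are appended only to idx_train; the validation and test sets contain only real nodes.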
        print(f"idx_train, idx_val, idx_test, num_real_samples分别为:{idx_train, idx_val, idx_test, num_real_samples}")
        return idx_train, idx_val, idx_test, num_real_samples

    def train_generator_step(self, original_graph, minority_features, gap_size, minority_indices):
        """生成器训练步骤:学习欺骗判别器"""
        self.generator.train()
        self.discriminator.eval()  # 固定BN/Dropout层

        # 1. 清空生成器梯度(判别器梯度不清空,会被backward计算但不会被应用)
        self.generator_optimizer.zero_grad()

        # 2. 生成合成节点
        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)
        adjacency_matrix = self.generator(noise)
        adjacency_matrix = (adjacency_matrix + 1) / 2
        
        link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
        synthetic_features = torch.mm(link_relationship, minority_features)

        # 3. Build the augmented graph
        augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
            original_graph, synthetic_features, link_relationship
        )

        # 4. Build the sparse adjacency matrix
        sparse_adj = torch.sparse_coo_tensor(
            augmented_edge_index,
            torch.ones(augmented_edge_index.shape[1], device=device),
            size=(augmented_features.shape[0], augmented_features.shape[0]),
            dtype=torch.float32,
            device=device
        )
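        # The adjacency is left unnormalized here (weight 1 per edge); the GCN imported from models
        # is assumed to handle any normalization it needs internally.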

        # 5. Discriminator forward pass (not wrapped in no_grad, so gradients can flow back to the generator)
        _, output_gen, _ = self.discriminator(augmented_features, sparse_adj)

        # 6. Generator loss: make the discriminator misclassify synthetic nodes as real (label 0)
        num_original = original_graph.x.shape[0]
        synthetic_start = num_original
        synthetic_end = num_original + gap_size
        
        generator_loss = F.nll_loss(
            output_gen[synthetic_start:synthetic_end],
            torch.zeros(gap_size, dtype=torch.long, device=device)
        )
        
        # 7. Backpropagate (gradients flow into both generator and discriminator parameters)
        generator_loss.backward()

        # 8. Update the generator parameters only (the discriminator is not stepped here)
        self.generator_optimizer.step()
        
        return generator_loss.item()
    
    def train_discriminator_step(self, original_graph, minority_indices, gap_size, 
                                idx_train, idx_val, idx_test):
        """判别器训练步骤:区分真实和合成节点"""
        self.discriminator.train()
        self.generator.eval()  # 固定生成器

        # 1. 清空判别器梯度
        self.discriminator_optimizer.zero_grad()

        # 2. 重新生成合成节点(使用评估模式的生成器)
        with torch.no_grad():
            minority_features = original_graph.x[minority_indices]
            noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)
            adjacency_matrix = self.generator(noise)
            adjacency_matrix = (adjacency_matrix + 1) / 2
            
            link_relationship = F.softmax(adjacency_matrix[:, :minority_features.shape[0]], dim=1)
            synthetic_features = torch.mm(link_relationship, minority_features)
            
            # Build the augmented graph and the labels
            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
                original_graph, synthetic_features, link_relationship
            )
            balanced_labels = self.augmentor.create_balanced_labels(original_graph.y, gap_size)

        # 3. Build the sparse adjacency matrix
        sparse_adj = torch.sparse_coo_tensor(
            augmented_edge_index,
            torch.ones(augmented_edge_index.shape[1], device=device),
            size=(augmented_features.shape[0], augmented_features.shape[0]),
            dtype=torch.float32,
            device=device
        )

        # 4. Discriminator forward pass
        output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)
        
        # 5. Build real/fake labels (real = 0, synthetic = 1)
        num_real_samples = original_graph.x.shape[0]
        real_labels = torch.cat([
            torch.zeros(num_real_samples, dtype=torch.long, device=device),
            torch.ones(gap_size, dtype=torch.long, device=device)
        ])
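        # real_labels are the real/fake targets for the output_gen head; they are distinct from
        # balanced_labels, which hold the class labels (negative/positive) used with output_real.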

        # 6. Compute the discriminator loss
        distance_loss = -euclidean_dist(
            original_graph.x[minority_indices],
            original_graph.x[minority_indices]
        ).mean()  # NOTE: this measures the minority set against itself and may need adjustment (e.g. distance between real minority and synthetic features)
        
        total_loss = (
            F.nll_loss(output_real[idx_train[:num_real_samples]], 
                      balanced_labels[idx_train[:num_real_samples]]) +
            F.nll_loss(output_gen[idx_train], real_labels[idx_train]) +
            distance_loss
        )
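        # NOTE: num_real_samples here counts all original nodes, which is not the same quantity as the
        # train-split count returned by create_data_splits; idx_train[:num_real_samples] is intended to
        # drop the synthetic indices from the class-label loss, but may still include some of them.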
        
        # 7. Backpropagate and update
        total_loss.backward()
        self.discriminator_optimizer.step()
        
        # 8. Validation
        self.discriminator.eval()
        with torch.no_grad():
            output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)
            
            recall_val, f1_val, auc_val, acc_val, pre_val = accuracy(
                output_real[idx_val], balanced_labels[idx_val], output_auc[idx_val]
            )
            recall_train, f1_train, auc_train, acc_train, pre_train = accuracy(
                output_real[idx_train[:num_real_samples]], 
                balanced_labels[idx_train[:num_real_samples]], 
                output_auc[idx_train[:num_real_samples]]
            )
        
        return {
            'discriminator_loss': total_loss.item(),
            'train_recall': recall_train, 'train_f1': f1_train, 'train_acc': acc_train,
            'val_recall': recall_val, 'val_f1': f1_val, 'val_acc': acc_val,
        }
    
    def train_epoch(self, original_graph, minority_indices, gap_size, 
                   idx_train, idx_val, idx_test):
        """完整训练epoch:先训练生成器,再训练判别器"""
        minority_features = original_graph.x[minority_indices]
        
        # 1. 训练生成器(让判别器误判)
        gen_loss = self.train_generator_step(original_graph, minority_features, gap_size, minority_indices)
        
        # 2. 训练判别器(正确区分)
        disc_metrics = self.train_discriminator_step(
            original_graph, minority_indices, gap_size, idx_train, idx_val, idx_test
        )
        
        # 合并结果
        disc_metrics['generator_loss'] = gen_loss
        return disc_metrics
    
    
class GraphBalanceProcessor:
    """主处理器 - 协调整个图数据平衡流程"""
    
    def __init__(self, config=None):
        self.config = config or GraphBalanceConfig()
        self.data_analyzer = GraphDataAnalyzer(self.config)
        self.model_builder = GANModelBuilder(self.config)
        self.augmentor = None
        self.trainer = None
        self.processed_graphs = []
        
    def process_single_graph(self, graph_data, graph_index):
        """处理单个图数据"""
        print(f"\n处理图 {graph_index}:")
        
        # 分析图的平衡情况
        analysis = self.data_analyzer.analyze_graph_balance(graph_data)
        gap_size = analysis['imbalance_gap']
        if gap_size == 0:
            print(f"  跳过: 已经平衡")
            return graph_data
        
        print(f"  需要生成 {gap_size} 个正类样本")
        
        # Prepare the data
        features = graph_data.x
        labels = graph_data.y
        minority_indices = torch.nonzero(labels.flatten() == 1).flatten()  # node indices of the positive (minority) class
        minority_features = features[minority_indices]

        # Build the models
        generator, generator_optimizer = self.model_builder.build_generator(
            output_dim=minority_features.shape[0]  # one link score per minority node (the node count, not the feature dimension)
        )
        discriminator, discriminator_optimizer = self.model_builder.build_discriminator(
            input_dim=features.shape[1]
        )
        
        # Create the augmentor
        self.augmentor = GraphAugmentor(self.config, generator, generator_optimizer)
        
        # *** Key change: pass every component into the trainer ***
        self.trainer = GANTrainer(
            self.config,
            discriminator,
            discriminator_optimizer,
            generator,  # new: the generator
            generator_optimizer,  # new: the generator optimizer
            self.augmentor  # new: the augmentor
        )
        
        # Create the data splits
        idx_train, idx_val, idx_test, num_real_samples = self.trainer.create_data_splits(
            analysis['num_positives'], analysis['num_negatives'], gap_size
        )
        
        # Train the GAN (both the generator and the discriminator are updated)
        print("  Starting adversarial training...")
        best_metrics = None
        best_score = 0
        
        for epoch in range(self.config.num_epochs):
            metrics = self.trainer.train_epoch(
                graph_data, minority_indices, gap_size, idx_train, idx_val, idx_test
            )
            
            if (epoch+1) % 10 == 0:
                print(f"  Epoch {epoch+1}/{self.config.num_epochs} | "
                      f"Gen Loss: {metrics['generator_loss']:.4f} | "
                      f"Disc Loss: {metrics['discriminator_loss']:.4f}")
            
            # Compute a combined validation score
            current_score = (metrics['val_recall'] + metrics['val_acc']) / 2
            if current_score > best_score:
                best_score = current_score
                best_metrics = metrics
        
        if best_metrics:
            print(f"  最佳验证性能: Recall={best_metrics['val_recall']:.4f}, "
                  f"F1={best_metrics['val_f1']:.4f}, Acc={best_metrics['val_acc']:.4f}")
        
        # Use the final generator to build the balanced graph
        with torch.no_grad():
            synthetic_features, _, link_relationship = self.augmentor.generate_synthetic_nodes(
                gap_size, graph_data.x[minority_indices]
            )
            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(
                graph_data, synthetic_features, link_relationship
            )
            balanced_labels = self.augmentor.create_balanced_labels(labels, gap_size)
        
        balanced_graph = Data(
            x=augmented_features,
            edge_index=augmented_edge_index,
            y=balanced_labels.view(1, -1)
        )
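        # y is stored with shape (1, num_nodes) to match the label layout assumed for the input graphs.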
        print(f"  平衡后的图:{balanced_graph}")
        return balanced_graph
    
    def process_all_graphs(self):
        """处理所有图数据"""
        print("=" * 80)
        print("开始处理图数据平衡任务")
        print("=" * 80)
        
        # 加载数据
        if not self.data_analyzer.load_data(self.config.input_path):
            return False
        
        # 分割数据
        well_balanced, sparse_positive = self.data_analyzer.split_data_by_balance()
        train_data, val_data = self.data_analyzer.create_train_test_split(well_balanced, sparse_positive)
        
        # 保存验证数据
        # torch.save(val_data, self.config.val_output_path)
        print(f"验证数据已保存到: {self.config.val_output_path}")
        
        # Process the training data
        processed_train_data = []

        for i, graph in enumerate(train_data):
            processed_graph = self.process_single_graph(graph, i)
            if processed_graph is not None:
                processed_train_data.append(processed_graph)

        # Save the processed training data
        torch.save(processed_train_data, self.config.train_output_path)
        print(f"\nTraining data saved to: {self.config.train_output_path}")
        print(f"Successfully processed {len(processed_train_data)} graphs")
        
        return True
    
if __name__ == "__main__":
    """主函数"""
    # 创建配置
    config = GraphBalanceConfig()
    
    # 创建处理器
    processor = GraphBalanceProcessor(config)
    
    # 执行处理
    success = processor.process_all_graphs()
    
    if success:
        print("\n" + "=" * 80)
        print("【success】---图数据平衡处理完成!")
        print("=" * 80)
    else:
        print("\n出了点问题呢!---处理失败,请检查错误信息,请大侠从头再来!")