{"id":172,"date":"2025-12-08T22:34:08","date_gmt":"2025-12-08T14:34:08","guid":{"rendered":"https:\/\/snakesleep.work\/?p=172"},"modified":"2025-12-08T22:34:08","modified_gmt":"2025-12-08T14:34:08","slug":"get_balanced_data-py","status":"publish","type":"post","link":"https:\/\/snakesleep.work\/?p=172","title":{"rendered":"get_balanced_data.py"},"content":{"rendered":"\n<pre class=\"wp-block-code\"><code>import warnings\nimport torch\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport numpy as np\nfrom torch_geometric.data import Data\nimport random\n\n# \u914d\u7f6e\u8b66\u544a\u548c\u968f\u673a\u79cd\u5b50\nwarnings.filterwarnings(\"ignore\")\nrandom.seed(42)\nnp.random.seed(42)\ntorch.manual_seed(42)\n\n# \u8bbe\u5907\u914d\u7f6e\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"\u4f7f\u7528\u8bbe\u5907\uff1a{device}\")\n\n# \u5bfc\u5165\u6a21\u578b\u548c\u5de5\u5177\u51fd\u6570\nfrom models import Generator, GCN\nfrom utils import euclidean_dist, accuracy\n\n\nclass GraphBalanceConfig:\n    \"\"\"\u914d\u7f6e\u7c7b - \u96c6\u4e2d\u7ba1\u7406\u6240\u6709\u8d85\u53c2\u6570\"\"\"\n    \n    def __init__(self):\n        # \u6a21\u578b\u53c2\u6570\n        self.generator_input_dim = 300  # \u751f\u6210\u5668\u8f93\u5165\u566a\u58f0\u7ef4\u5ea6\n        self.hidden_dim = 1280  # GCN\u9690\u85cf\u5c42\u7ef4\u5ea6\n        self.num_classes = 2  # \u5206\u7c7b\u7c7b\u522b\u6570\n        self.dropout_rate = 0.1  # Dropout\u7387\n        \n        # \u8bad\u7ec3\u53c2\u6570\n        self.learning_rate = 0.0001  # \u5b66\u4e60\u7387\n        self.weight_decay = 0.0005  # \u6743\u91cd\u8870\u51cf\n        self.num_epochs = 300  # \u6bcf\u4e2a\u56fe\u7684\u8bad\u7ec3\u8f6e\u6570\n        self.num_iterations = 10  # \u751f\u6210\u5668\u8fed\u4ee3\u6b21\u6570\n        \n        # \u6570\u636e\u53c2\u6570\n        self.threshold = 0.038  # \u90bb\u63a5\u77e9\u9635\u9608\u503c\n        self.train_ratio = 0.6  # \u8bad\u7ec3\u96c6\u6bd4\u4f8b\n        self.val_ratio = 0.2  # \u9a8c\u8bc1\u96c6\u6bd4\u4f8b\n        self.test_ratio = 0.2  # \u6d4b\u8bd5\u96c6\u6bd4\u4f8b\n        self.min_positive_samples = 5  # \u6700\u5c0f\u6b63\u7c7b\u6837\u672c\u6570\n        \n        # \u8def\u5f84\u914d\u7f6e\n        self.input_path = '.\/ImGAGN\/573.pt'\n        self.train_output_path = '.\/ImGAGN\/573_train.pt'\n        self.val_output_path = '.\/ImGAGN\/573_val.pt'\n\n\nclass GraphDataAnalyzer:\n    \"\"\"\u56fe\u6570\u636e\u5206\u6790\u5668 - \u8d1f\u8d23\u6570\u636e\u52a0\u8f7d\u3001\u5206\u6790\u548c\u9884\u5904\u7406\"\"\"\n    \n    def __init__(self, config):\n        self.config = config\n        self.original_data = None\n        self.train_data = None\n        self.val_data = None\n        \n    def load_data(self, file_path):\n        \"\"\"\u52a0\u8f7d\u56fe\u6570\u636e\u6587\u4ef6\"\"\"\n        try:\n            print(f\"\u6b63\u5728\u52a0\u8f7d\u6570\u636e\u6587\u4ef6: {file_path}\")\n            self.original_data = torch.load(file_path, weights_only=False, map_location=device)\n            print(f\"\u6210\u529f\u52a0\u8f7d {len(self.original_data)} \u4e2a\u56fe\u6570\u636e\u5bf9\u8c61\")\n            return True\n        except Exception as e:\n            print(f\"\u52a0\u8f7d\u6570\u636e\u5931\u8d25: {e}\")\n            return False\n    \n    def analyze_graph_balance(self, graph_data):\n        \"\"\"\u5206\u6790\u5355\u4e2a\u56fe\u7684\u7c7b\u522b\u5e73\u8861\u60c5\u51b5\"\"\"\n        labels = graph_data.y\n        num_negatives = torch.sum(labels == 0).item()\n        num_positives = torch.sum(labels == 1).item()\n        imbalance_gap = abs(num_negatives - num_positives)\n        \n        return {\n            'num_nodes': graph_data.num_nodes,\n            'num_negatives': num_negatives,\n            'num_positives': num_positives,\n            'imbalance_gap': imbalance_gap,\n            'positive_ratio': num_positives \/ graph_data.num_nodes if graph_data.num_nodes > 0 else 0\n        }\n    \n    def split_data_by_balance(self):\n        \"\"\"\u6839\u636e\u6b63\u7c7b\u6837\u672c\u6570\u91cf\u5206\u5272\u6570\u636e\"\"\"\n        if not self.original_data:\n            raise ValueError(\"\u8bf7\u5148\u52a0\u8f7d\u6570\u636e\")\n        \n        well_balanced_graphs = &#91;]  # \u6b63\u7c7b\u6837\u672c >= 5\u7684\u56fe\n        sparse_positive_graphs = &#91;]  # \u6b63\u7c7b\u6837\u672c &lt; 5\u7684\u56fe\n        \n        for graph in self.original_data:\n            analysis = self.analyze_graph_balance(graph)\n            if analysis&#91;'num_positives'] >= self.config.min_positive_samples:\n                well_balanced_graphs.append(graph)\n            else:\n                sparse_positive_graphs.append(graph)\n        \n        print(f\"\u6570\u636e\u5206\u5272\u7ed3\u679c:\")\n        print(f\"  - \u6b63\u7c7b\u4e30\u5bcc\u7684\u56fe: {len(well_balanced_graphs)} \u4e2a\")\n        print(f\"  - \u6b63\u7c7b\u7a00\u758f\u7684\u56fe: {len(sparse_positive_graphs)} \u4e2a\")\n        \n        return well_balanced_graphs, sparse_positive_graphs\n    \n    def create_train_test_split(self, well_balanced_graphs, sparse_positive_graphs):\n        \"\"\"\u521b\u5efa\u8bad\u7ec3\u6d4b\u8bd5\u5206\u5272\"\"\"\n        # \u6253\u4e71\u987a\u5e8f\n        random.shuffle(well_balanced_graphs)\n        \n        # \u6309\u71677:3\u6bd4\u4f8b\u5206\u5272\n        train_size = int(len(self.original_data) * 0.7)\n        self.train_data = well_balanced_graphs&#91;:train_size]\n        self.val_data = well_balanced_graphs&#91;train_size:] + sparse_positive_graphs\n        \n        print(f\"\u8bad\u7ec3\u96c6: {len(self.train_data)} \u4e2a\u56fe\")\n        print(f\"\u9a8c\u8bc1\/\u6d4b\u8bd5\u96c6: {len(self.val_data)} \u4e2a\u56fe\")\n        \n        return self.train_data, self.val_data\n\n\nclass GANModelBuilder:\n    \"\"\"GAN\u6a21\u578b\u6784\u5efa\u5668 - \u8d1f\u8d23\u521b\u5efa\u548c\u914d\u7f6e\u751f\u6210\u5668\u3001\u5224\u522b\u5668\"\"\"\n    \n    def __init__(self, config):\n        self.config = config\n        self.generator = None\n        self.discriminator = None\n        \n    def build_generator(self, output_dim):\n        \"\"\"\u6784\u5efa\u751f\u6210\u5668\u6a21\u578b\"\"\"\n        self.generator = Generator(\n            input_dim=self.config.generator_input_dim,\n            output_dim=output_dim\n        ).to(device)\n        \n        self.generator_optimizer = optim.Adam(\n            self.generator.parameters(),\n            lr=self.config.learning_rate,\n            weight_decay=self.config.weight_decay\n        )\n        \n        return self.generator, self.generator_optimizer\n    \n    def build_discriminator(self, input_dim):\n        \"\"\"\u6784\u5efa\u5224\u522b\u5668\u6a21\u578b\"\"\"\n        self.discriminator = GCN(\n            nfeat=input_dim,\n            nhid=self.config.hidden_dim,\n            nclass=self.config.num_classes,\n            dropout=self.config.dropout_rate\n        ).to(device)\n        \n        self.discriminator_optimizer = optim.Adam(\n            self.discriminator.parameters(),\n            lr=self.config.learning_rate,\n            weight_decay=self.config.weight_decay\n        )\n        \n        return self.discriminator, self.discriminator_optimizer\n\n\nclass GraphAugmentor:\n    \"\"\"\u56fe\u6570\u636e\u589e\u5f3a\u5668 - \u8d1f\u8d23\u751f\u6210\u5408\u6210\u8282\u70b9\u548c\u5e73\u8861\u56fe\u6570\u636e\"\"\"\n    \n    def __init__(self, config, generator, generator_optimizer):\n        self.config = config\n        self.generator = generator\n        self.generator_optimizer = generator_optimizer\n        \n    def calculate_imbalance_gap(self, labels):\n        \"\"\"\u8ba1\u7b97\u7c7b\u522b\u4e0d\u5e73\u8861\u5dee\u8ddd\"\"\"\n        num_negatives = torch.sum(labels == 0).item()\n        num_positives = torch.sum(labels == 1).item()\n        return abs(num_negatives - num_positives), num_positives, num_negatives\n    \n    def generate_synthetic_nodes(self, gap_size, minority_features):\n        \"\"\"\u751f\u6210\u5408\u6210\u8282\u70b9\u7279\u5f81\"\"\"\n        \n        # \u751f\u6210\u968f\u673a\u566a\u58f0\n        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)\n\n        # \u4f7f\u7528\u751f\u6210\u5668\u751f\u6210\u90bb\u63a5\u5173\u7cfb\n        adjacency_matrix = self.generator(noise)\n        adjacency_matrix = (adjacency_matrix+1)\/2\n        print(f\"\u5173\u8054\u77e9\u9635\u7684\u5f62\u72b6\uff1a{adjacency_matrix.shape}\")\n        \n        # \u8ba1\u7b97\u4e0e\u5c11\u6570\u7c7b\u6837\u672c\u7684\u8fde\u63a5\u5173\u7cfb\n        link_relationship = F.softmax(adjacency_matrix&#91;:, :minority_features.shape&#91;0]], dim=1)\n        print(f\"\u8fde\u63a5\u5173\u7cfb\u7684\u5f62\u72b6\uff1a{link_relationship.shape}\")\n\n        # \u751f\u6210\u8282\u70b9\u7279\u5f81\n        synthetic_features = torch.mm(link_relationship, minority_features)\n        \n        return synthetic_features, adjacency_matrix, link_relationship\n    \n    def build_augmented_graph(self, original_graph, synthetic_features, link_relationship):\n        \"\"\"\u6784\u5efa\u589e\u5f3a\u540e\u7684\u56fe\"\"\"\n        original_features = original_graph.x\n        original_edge_index = original_graph.edge_index\n        \n        # \u5408\u5e76\u7279\u5f81\n        augmented_features = torch.cat((original_features, synthetic_features), dim=0)\n        \n        # \u6784\u5efa\u65b0\u751f\u6210\u8282\u70b9\u7684\u8fb9\u8fde\u63a5\n        threshold = self.config.threshold\n        adjacency_mask = (link_relationship > threshold).int()\n        \n        # \u521b\u5efa\u8282\u70b9\u7d22\u5f15\u6620\u5c04\n        num_original_nodes = original_features.shape&#91;0]\n        num_synthetic_nodes = synthetic_features.shape&#91;0]\n        \n        # \u6784\u5efa\u65b0\u7684\u8fb9\u7d22\u5f15\n        synthetic_edges = &#91;]\n        nonzero_positions = torch.nonzero(adjacency_mask)\n        \n        for pos in nonzero_positions:\n            synthetic_node_id = pos&#91;0].item() + num_original_nodes # synthetic_node_id = 0(\u5408\u6210\u8282\u70b9\u4e2d\u7684\u7d22\u5f15) + 3(\u539f\u59cb\u5c11\u6570\u7c7b\u4e2a\u6570) = 3(\u5373\u5408\u6210\u8282\u70b9\u52a0\u5165\u539f\u59cb\u5c11\u6570\u7c7b\u540e\u7684\u5168\u5c40\u7d22\u5f15)\n            original_node_id = pos&#91;1].item() # original_node_id = 1\n            synthetic_edges.append(&#91;synthetic_node_id, original_node_id])\n            synthetic_edges.append(&#91;original_node_id, synthetic_node_id])  # \u53cc\u5411\u8fde\u63a5 \u8fd9\u4e24\u6b65\u6dfb\u52a0\u8fb9\uff1a&#91;3, 1] \u548c &#91;1, 3]\n        \n        if synthetic_edges:\n            synthetic_edge_index = torch.tensor(synthetic_edges, device=device).t()\n            augmented_edge_index = torch.cat((original_edge_index, synthetic_edge_index), dim=1)\n        else:\n            augmented_edge_index = original_edge_index\n        \n        return augmented_features, augmented_edge_index\n    \n    def create_balanced_labels(self, original_labels, num_synthetic_nodes):\n        \"\"\"\u521b\u5efa\u5e73\u8861\u540e\u7684\u6807\u7b7e\"\"\"\n        synthetic_labels = torch.ones((1, num_synthetic_nodes), dtype=original_labels.dtype, device=device)\n        balanced_labels = torch.cat((original_labels, synthetic_labels), dim=1)\n        return balanced_labels.flatten()\n\n\nclass GANTrainer:\n    \"\"\"GAN\u8bad\u7ec3\u5668 - \u73b0\u5728\u5b9e\u73b0\u771f\u6b63\u7684\u5bf9\u6297\u8bad\u7ec3\"\"\"\n    \n    def __init__(self, config, discriminator, discriminator_optimizer, \n                 generator, generator_optimizer, augmentor):\n        self.config = config\n        self.discriminator = discriminator\n        self.discriminator_optimizer = discriminator_optimizer\n        self.generator = generator\n        self.generator_optimizer = generator_optimizer\n        self.augmentor = augmentor\n        self.training_history = &#91;]\n \n    def create_data_splits(self, total_positive_samples, total_negative_samples, num_synthetic_nodes):\n        \"\"\"\u521b\u5efa\u8bad\u7ec3\u3001\u9a8c\u8bc1\u3001\u6d4b\u8bd5\u6570\u636e\u5206\u5272\"\"\"\n        # \u6b63\u6837\u672c\u5206\u5272\n        positive_indices = torch.arange(total_positive_samples)\n        train_size_pos = int(len(positive_indices) * self.config.train_ratio)\n        val_size_pos = int(len(positive_indices) * self.config.val_ratio)\n        \n        permuted_pos = positive_indices&#91;torch.randperm(len(positive_indices))]\n        idx_train_pos = permuted_pos&#91;:train_size_pos]\n        idx_val_pos = permuted_pos&#91;train_size_pos:train_size_pos + val_size_pos]\n        idx_test_pos = permuted_pos&#91;train_size_pos + val_size_pos:]\n        \n        # \u8d1f\u6837\u672c\u5206\u5272\n        negative_indices = torch.arange(total_negative_samples)\n        train_size_neg = int(len(negative_indices) * self.config.train_ratio)\n        val_size_neg = int(len(negative_indices) * self.config.val_ratio)\n        \n        permuted_neg = negative_indices&#91;torch.randperm(len(negative_indices))]\n        idx_train_neg = permuted_neg&#91;:train_size_neg]\n        idx_val_neg = permuted_neg&#91;train_size_neg:train_size_neg + val_size_neg]\n        idx_test_neg = permuted_neg&#91;train_size_neg + val_size_neg:]\n        \n        # \u5408\u5e76\u7d22\u5f15\n        num_real_samples = train_size_pos + train_size_neg\n        synthetic_indices = torch.arange(total_positive_samples + total_negative_samples, \n                                       total_positive_samples + total_negative_samples + num_synthetic_nodes)\n        \n        idx_train = torch.cat((idx_train_pos, idx_train_neg, synthetic_indices))\n        idx_val = torch.cat((idx_val_pos, idx_val_neg))\n        idx_test = torch.cat((idx_test_pos, idx_test_neg))\n        print(f\"idx_train, idx_val, idx_test, num_real_samples\u5206\u522b\u4e3a:{idx_train, idx_val, idx_test, num_real_samples}\")\n        return idx_train, idx_val, idx_test, num_real_samples\n\n    def train_generator_step(self, original_graph, minority_features, gap_size, minority_indices):\n        \"\"\"\u751f\u6210\u5668\u8bad\u7ec3\u6b65\u9aa4\uff1a\u5b66\u4e60\u6b3a\u9a97\u5224\u522b\u5668\"\"\"\n        self.generator.train()\n        self.discriminator.eval()  # \u56fa\u5b9aBN\/Dropout\u5c42\n\n        # 1. \u6e05\u7a7a\u751f\u6210\u5668\u68af\u5ea6\uff08\u5224\u522b\u5668\u68af\u5ea6\u4e0d\u6e05\u7a7a\uff0c\u4f1a\u88abbackward\u8ba1\u7b97\u4f46\u4e0d\u4f1a\u88ab\u5e94\u7528\uff09\n        self.generator_optimizer.zero_grad()\n\n        # 2. \u751f\u6210\u5408\u6210\u8282\u70b9\n        noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)\n        adjacency_matrix = self.generator(noise)\n        adjacency_matrix = (adjacency_matrix + 1) \/ 2\n        \n        link_relationship = F.softmax(adjacency_matrix&#91;:, :minority_features.shape&#91;0]], dim=1)\n        synthetic_features = torch.mm(link_relationship, minority_features)\n\n        # 3. \u6784\u5efa\u589e\u5f3a\u56fe\n        augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(\n            original_graph, synthetic_features, link_relationship\n        )\n\n        # 4. \u6784\u5efa\u7a00\u758f\u90bb\u63a5\u77e9\u9635\n        sparse_adj = torch.sparse_coo_tensor(\n            augmented_edge_index,\n            torch.ones(augmented_edge_index.shape&#91;1]),\n            size=(augmented_features.shape&#91;0], augmented_features.shape&#91;0]),\n            dtype=torch.float32,\n            device=device\n        )\n\n        # 5. \u5224\u522b\u5668\u9884\u6d4b\uff08\u4e0d\u5305\u88f9no_grad\uff0c\u8ba9\u68af\u5ea6\u6d41\u56de\u751f\u6210\u5668\uff09\n        _, output_gen, _ = self.discriminator(augmented_features, sparse_adj)\n        \n        # 6. \u751f\u6210\u5668\u635f\u5931\uff1a\u8ba9\u5224\u522b\u5668\u5c06\u5408\u6210\u8282\u70b9\u8bef\u5224\u4e3a\u771f\u5b9e\uff08\u6807\u7b7e0\uff09\n        num_original = original_graph.x.shape&#91;0]\n        synthetic_start = num_original\n        synthetic_end = num_original + gap_size\n        \n        generator_loss = F.nll_loss(\n            output_gen&#91;synthetic_start:synthetic_end],\n            torch.zeros(gap_size, dtype=torch.long, device=device)\n        )\n        \n        # 7. \u53cd\u5411\u4f20\u64ad\uff08\u68af\u5ea6\u4f1a\u6d41\u5411\u751f\u6210\u5668\u548c\u5224\u522b\u5668\u53c2\u6570\uff09\n        generator_loss.backward()\n        \n        # 8. \u66f4\u65b0\u751f\u6210\u5668\u53c2\u6570\uff08\u5224\u522b\u5668\u53c2\u6570\u672a\u66f4\u65b0\uff09\n        self.generator_optimizer.step()\n        \n        return generator_loss.item()\n    \n    def train_discriminator_step(self, original_graph, minority_indices, gap_size, \n                                idx_train, idx_val, idx_test):\n        \"\"\"\u5224\u522b\u5668\u8bad\u7ec3\u6b65\u9aa4\uff1a\u533a\u5206\u771f\u5b9e\u548c\u5408\u6210\u8282\u70b9\"\"\"\n        self.discriminator.train()\n        self.generator.eval()  # \u56fa\u5b9a\u751f\u6210\u5668\n\n        # 1. \u6e05\u7a7a\u5224\u522b\u5668\u68af\u5ea6\n        self.discriminator_optimizer.zero_grad()\n\n        # 2. \u91cd\u65b0\u751f\u6210\u5408\u6210\u8282\u70b9\uff08\u4f7f\u7528\u8bc4\u4f30\u6a21\u5f0f\u7684\u751f\u6210\u5668\uff09\n        with torch.no_grad():\n            minority_features = original_graph.x&#91;minority_indices]\n            noise = torch.randn(gap_size, self.config.generator_input_dim, device=device)\n            adjacency_matrix = self.generator(noise)\n            adjacency_matrix = (adjacency_matrix + 1) \/ 2\n            \n            link_relationship = F.softmax(adjacency_matrix&#91;:, :minority_features.shape&#91;0]], dim=1)\n            synthetic_features = torch.mm(link_relationship, minority_features)\n            \n            # \u6784\u5efa\u589e\u5f3a\u56fe\u548c\u6807\u7b7e\n            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(\n                original_graph, synthetic_features, link_relationship\n            )\n            balanced_labels = self.augmentor.create_balanced_labels(original_graph.y, gap_size)\n\n        # 3. \u6784\u5efa\u7a00\u758f\u90bb\u63a5\u77e9\u9635\n        sparse_adj = torch.sparse_coo_tensor(\n            augmented_edge_index,\n            torch.ones(augmented_edge_index.shape&#91;1]),\n            size=(augmented_features.shape&#91;0], augmented_features.shape&#91;0]),\n            dtype=torch.float32,\n            device=device\n        )\n\n        # 4. \u5224\u522b\u5668\u524d\u5411\u4f20\u64ad\n        output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)\n        \n        # 5. \u6784\u5efa\u771f\u5b9e\u6807\u7b7e\uff08\u771f\u5b9e=0\uff0c\u5408\u6210=1\uff09\n        num_real_samples = original_graph.x.shape&#91;0]\n        real_labels = torch.cat(&#91;\n            torch.zeros(num_real_samples, dtype=torch.long, device=device),\n            torch.ones(gap_size, dtype=torch.long, device=device)\n        ])\n\n        # 6. \u8ba1\u7b97\u5224\u522b\u5668\u635f\u5931\n        distance_loss = -euclidean_dist(\n            original_graph.x&#91;minority_indices], \n            original_graph.x&#91;minority_indices]\n        ).mean()  # \u6ce8\u610f\uff1a\u8fd9\u4e2a\u635f\u5931\u51fd\u6570\u53ef\u80fd\u9700\u8981\u8c03\u6574\n        \n        total_loss = (\n            F.nll_loss(output_real&#91;idx_train&#91;:num_real_samples]], \n                      balanced_labels&#91;idx_train&#91;:num_real_samples]]) +\n            F.nll_loss(output_gen&#91;idx_train], real_labels&#91;idx_train]) +\n            distance_loss\n        )\n        \n        # 7. \u53cd\u5411\u4f20\u64ad\u548c\u4f18\u5316\n        total_loss.backward()\n        self.discriminator_optimizer.step()\n        \n        # 8. \u9a8c\u8bc1\n        self.discriminator.eval()\n        with torch.no_grad():\n            output_real, output_gen, output_auc = self.discriminator(augmented_features, sparse_adj)\n            \n            recall_val, f1_val, auc_val, acc_val, pre_val = accuracy(\n                output_real&#91;idx_val], balanced_labels&#91;idx_val], output_auc&#91;idx_val]\n            )\n            recall_train, f1_train, auc_train, acc_train, pre_train = accuracy(\n                output_real&#91;idx_train&#91;:num_real_samples]], \n                balanced_labels&#91;idx_train&#91;:num_real_samples]], \n                output_auc&#91;idx_train&#91;:num_real_samples]]\n            )\n        \n        return {\n            'discriminator_loss': total_loss.item(),\n            'train_recall': recall_train, 'train_f1': f1_train, 'train_acc': acc_train,\n            'val_recall': recall_val, 'val_f1': f1_val, 'val_acc': acc_val,\n        }\n    \n    def train_epoch(self, original_graph, minority_indices, gap_size, \n                   idx_train, idx_val, idx_test):\n        \"\"\"\u5b8c\u6574\u8bad\u7ec3epoch\uff1a\u5148\u8bad\u7ec3\u751f\u6210\u5668\uff0c\u518d\u8bad\u7ec3\u5224\u522b\u5668\"\"\"\n        minority_features = original_graph.x&#91;minority_indices]\n        \n        # 1. \u8bad\u7ec3\u751f\u6210\u5668\uff08\u8ba9\u5224\u522b\u5668\u8bef\u5224\uff09\n        gen_loss = self.train_generator_step(original_graph, minority_features, gap_size, minority_indices)\n        \n        # 2. \u8bad\u7ec3\u5224\u522b\u5668\uff08\u6b63\u786e\u533a\u5206\uff09\n        disc_metrics = self.train_discriminator_step(\n            original_graph, minority_indices, gap_size, idx_train, idx_val, idx_test\n        )\n        \n        # \u5408\u5e76\u7ed3\u679c\n        disc_metrics&#91;'generator_loss'] = gen_loss\n        return disc_metrics\n    \n    \nclass GraphBalanceProcessor:\n    \"\"\"\u4e3b\u5904\u7406\u5668 - \u534f\u8c03\u6574\u4e2a\u56fe\u6570\u636e\u5e73\u8861\u6d41\u7a0b\"\"\"\n    \n    def __init__(self, config=None):\n        self.config = config or GraphBalanceConfig()\n        self.data_analyzer = GraphDataAnalyzer(self.config)\n        self.model_builder = GANModelBuilder(self.config)\n        self.augmentor = None\n        self.trainer = None\n        self.processed_graphs = &#91;]\n        \n    def process_single_graph(self, graph_data, graph_index):\n        \"\"\"\u5904\u7406\u5355\u4e2a\u56fe\u6570\u636e\"\"\"\n        print(f\"\\n\u5904\u7406\u56fe {graph_index}:\")\n        \n        # \u5206\u6790\u56fe\u7684\u5e73\u8861\u60c5\u51b5\n        analysis = self.data_analyzer.analyze_graph_balance(graph_data)\n        gap_size = analysis&#91;'imbalance_gap']\n        if gap_size == 0:\n            print(f\"  \u8df3\u8fc7: \u5df2\u7ecf\u5e73\u8861\")\n            return graph_data\n        \n        print(f\"  \u9700\u8981\u751f\u6210 {gap_size} \u4e2a\u6b63\u7c7b\u6837\u672c\")\n        \n        # \u51c6\u5907\u6570\u636e\n        features = graph_data.x\n        labels = graph_data.y\n        minority_indices = torch.nonzero(labels == 1)&#91;:, 0]\n        minority_features = features&#91;minority_indices]\n\n        # \u6784\u5efa\u6a21\u578b\n        generator, generator_optimizer = self.model_builder.build_generator(\n            output_dim=minority_features.shape&#91;0]  # \u6ce8\u610f\uff1a\u8fd9\u91cc\u5e94\u8be5\u6839\u636e\u5c11\u6570\u7c7b\u7279\u5f81\u7ef4\u5ea6\u8c03\u6574\n        )\n        discriminator, discriminator_optimizer = self.model_builder.build_discriminator(\n            input_dim=features.shape&#91;1]\n        )\n        \n        # \u521b\u5efa\u589e\u5f3a\u5668\n        self.augmentor = GraphAugmentor(self.config, generator, generator_optimizer)\n        \n        # ***\u5173\u952e\u4fee\u6539\uff1a\u521b\u5efa\u8bad\u7ec3\u5668\u65f6\u4f20\u5165\u6240\u6709\u7ec4\u4ef6***\n        self.trainer = GANTrainer(\n            self.config,\n            discriminator,\n            discriminator_optimizer,\n            generator,  # \u65b0\u589e\uff1a\u4f20\u5165\u751f\u6210\u5668\n            generator_optimizer,  # \u65b0\u589e\uff1a\u4f20\u5165\u751f\u6210\u5668\u4f18\u5316\u5668\n            self.augmentor  # \u65b0\u589e\uff1a\u4f20\u5165\u589e\u5f3a\u5668\n        )\n        \n        # \u521b\u5efa\u6570\u636e\u5206\u5272\n        idx_train, idx_val, idx_test, num_real_samples = self.trainer.create_data_splits(\n            analysis&#91;'num_positives'], analysis&#91;'num_negatives'], gap_size\n        )\n        \n        # \u8bad\u7ec3GAN\uff08\u73b0\u5728\u751f\u6210\u5668\u548c\u5224\u522b\u5668\u90fd\u4f1a\u66f4\u65b0\uff09\n        print(f\"  \u5f00\u59cb\u5bf9\u6297\u8bad\u7ec3...\")\n        best_metrics = None\n        best_score = 0\n        \n        for epoch in range(self.config.num_epochs):\n            metrics = self.trainer.train_epoch(\n                graph_data, minority_indices, gap_size, idx_train, idx_val, idx_test\n            )\n            \n            if (epoch+1) % 10 == 0:\n                print(f\"  Epoch {epoch+1}\/{self.config.num_epochs} | \"\n                      f\"Gen Loss: {metrics&#91;'generator_loss']:.4f} | \"\n                      f\"Disc Loss: {metrics&#91;'discriminator_loss']:.4f}\")\n            \n            # \u8ba1\u7b97\u7efc\u5408\u8bc4\u5206\n            current_score = (metrics&#91;'val_recall'] + metrics&#91;'val_acc']) \/ 2\n            if current_score > best_score:\n                best_score = current_score\n                best_metrics = metrics\n        \n        if best_metrics:\n            print(f\"  \u6700\u4f73\u9a8c\u8bc1\u6027\u80fd: Recall={best_metrics&#91;'val_recall']:.4f}, \"\n                  f\"F1={best_metrics&#91;'val_f1']:.4f}, Acc={best_metrics&#91;'val_acc']:.4f}\")\n        \n        # \u4f7f\u7528\u6700\u7ec8\u751f\u6210\u5668\u751f\u6210\u5e73\u8861\u540e\u7684\u56fe\n        with torch.no_grad():\n            synthetic_features, _, link_relationship = self.augmentor.generate_synthetic_nodes(\n                gap_size, graph_data.x&#91;minority_indices]\n            )\n            augmented_features, augmented_edge_index = self.augmentor.build_augmented_graph(\n                graph_data, synthetic_features, link_relationship\n            )\n            balanced_labels = self.augmentor.create_balanced_labels(labels, gap_size)\n        \n        balanced_graph = Data(\n            x=augmented_features,\n            edge_index=augmented_edge_index,\n            y=balanced_labels.view(1, -1)\n        )\n        print(f\"  \u5e73\u8861\u540e\u7684\u56fe\uff1a{balanced_graph}\")\n        return balanced_graph\n    \n    def process_all_graphs(self):\n        \"\"\"\u5904\u7406\u6240\u6709\u56fe\u6570\u636e\"\"\"\n        print(\"=\" * 80)\n        print(\"\u5f00\u59cb\u5904\u7406\u56fe\u6570\u636e\u5e73\u8861\u4efb\u52a1\")\n        print(\"=\" * 80)\n        \n        # \u52a0\u8f7d\u6570\u636e\n        if not self.data_analyzer.load_data(self.config.input_path):\n            return False\n        \n        # \u5206\u5272\u6570\u636e\n        well_balanced, sparse_positive = self.data_analyzer.split_data_by_balance()\n        train_data, val_data = self.data_analyzer.create_train_test_split(well_balanced, sparse_positive)\n        \n        # \u4fdd\u5b58\u9a8c\u8bc1\u6570\u636e\n        # torch.save(val_data, self.config.val_output_path)\n        print(f\"\u9a8c\u8bc1\u6570\u636e\u5df2\u4fdd\u5b58\u5230: {self.config.val_output_path}\")\n        \n        # \u5904\u7406\u8bad\u7ec3\u6570\u636e\n        processed_train_data = &#91;]\n        \n        for i, graph in enumerate(train_data):\n            processed_graph = self.process_single_graph(graph, i)\n            if processed_graph is not None:\n                processed_train_data.append(processed_graph)\n            break\n        \n        # \u4fdd\u5b58\u5904\u7406\u540e\u7684\u8bad\u7ec3\u6570\u636e\n        # torch.save(processed_train_data, self.config.train_output_path)\n        print(f\"\\n\u8bad\u7ec3\u6570\u636e\u5df2\u4fdd\u5b58\u5230: {self.config.train_output_path}\")\n        print(f\"\u6210\u529f\u5904\u7406 {len(processed_train_data)} \u4e2a\u56fe\")\n        \n        return True\n    \nif __name__ == \"__main__\":\n    \"\"\"\u4e3b\u51fd\u6570\"\"\"\n    # \u521b\u5efa\u914d\u7f6e\n    config = GraphBalanceConfig()\n    \n    # \u521b\u5efa\u5904\u7406\u5668\n    processor = GraphBalanceProcessor(config)\n    \n    # \u6267\u884c\u5904\u7406\n    success = processor.process_all_graphs()\n    \n    if success:\n        print(\"\\n\" + \"=\" * 80)\n        print(\"\u3010success\u3011---\u56fe\u6570\u636e\u5e73\u8861\u5904\u7406\u5b8c\u6210\uff01\")\n        print(\"=\" * 80)\n    else:\n        print(\"\\n\u51fa\u4e86\u70b9\u95ee\u9898\u5462\uff01---\u5904\u7406\u5931\u8d25\uff0c\u8bf7\u68c0\u67e5\u9519\u8bef\u4fe1\u606f\uff0c\u8bf7\u5927\u4fa0\u4ece\u5934\u518d\u6765\uff01\")\n<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[25],"tags":[28,26,27],"class_list":["post-172","post","type-post","status-publish","format-standard","hentry","category-25","tag-gan","tag-26","tag-27"],"_links":{"self":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/172","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=172"}],"version-history":[{"count":1,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/172\/revisions"}],"predecessor-version":[{"id":173,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/172\/revisions\/173"}],"wp:attachment":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=172"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=172"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=172"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}