{"id":174,"date":"2025-12-08T23:46:02","date_gmt":"2025-12-08T15:46:02","guid":{"rendered":"https:\/\/snakesleep.work\/?p=174"},"modified":"2025-12-08T23:46:03","modified_gmt":"2025-12-08T15:46:03","slug":"get_balance_graph_data-py%e5%be%85%e5%ae%8c%e6%88%90","status":"publish","type":"post","link":"https:\/\/snakesleep.work\/?p=174","title":{"rendered":"get_balance_graph_data.py(\u5f85\u5b8c\u6210)"},"content":{"rendered":"\n<pre class=\"wp-block-code\"><code>import torch\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn import GCNConv, global_mean_pool\nfrom torch_geometric.utils import k_hop_subgraph, is_undirected, to_undirected\nfrom typing import List, Tuple\nimport evaluate\nimport warnings\nwarnings.filterwarnings('ignore')\n\n# ==================== \u8d85\u53c2\u6570\u914d\u7f6e\u533a ====================\nDEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\nLATENT_DIM = 64  # \u751f\u6210\u5668\u5185\u90e8\u9690\u7a7a\u95f4\u7ef4\u5ea6\nHIDDEN_DIM = 256  # \u56fe\u795e\u7ecf\u7f51\u7edc\u9690\u85cf\u5c42\u7ef4\u5ea6\nK_CANDIDATES = 5  # \u5019\u9009\u90bb\u5c45\u6570\u91cf\uff08\u53ef\u8c03\u8d85\u53c2\u6570\uff09\nNUM_EPOCHS_PER_GRAPH = 50  # \u6bcf\u4e2a\u56fe\u7684\u8bad\u7ec3\u8f6e\u6570\nLR_GAN = 1e-3  # \u751f\u6210\u5668\u548c\u5224\u522b\u5668\u5b66\u4e60\u7387\nEDGE_THRESHOLD = 0.3  # \u751f\u6210\u8fb9\u6982\u7387\u9608\u503c\nSEED = 42\ntorch.manual_seed(SEED)\nnp.random.seed(SEED)\n\n\n# ==================== 1. Geo-ImGAGN \u6a21\u578b\u5b9a\u4e49 ====================\n\nclass GeoGraphGenerator(nn.Module):\n    \"\"\"\n    \u751f\u6210\u5668\uff1a\u8f93\u5165\u5c11\u6570\u7c7b\u4e2d\u5fc3\u8282\u70b9\uff0c\u8f93\u51fa\u5408\u6210\u8282\u70b9\u7279\u5f81 + \u5230\u5019\u9009\u90bb\u5c45\u7684\u8fb9\u6743\u91cd\n    \u6838\u5fc3\uff1a\u5b66\u4e60\u53e3\u888b\u5fae\u73af\u5883\u7684\u62d3\u6251-\u7279\u5f81\u534f\u540c\u6a21\u5f0f\n    \"\"\"\n    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM, k_candidates: int = K_CANDIDATES):\n        super().__init__()\n        # \u56fe\u7f16\u7801\u5668\uff1a\u805a\u5408\u4e2d\u5fc3\u53ca\u5176k\u9636\u90bb\u5c45\u4fe1\u606f\n        self.conv1 = GCNConv(node_dim, hidden_dim)\n        self.conv2 = GCNConv(hidden_dim, hidden_dim)\n        \n        # \u89e3\u7801\u5668\uff1a\u5408\u6210\u8282\u70b9\u7279\u5f81\uff08ESM2\u98ce\u683c\uff09\n        self.feat_decoder = nn.Sequential(\n            nn.Linear(hidden_dim * 2, hidden_dim),\n            nn.ReLU(),\n            nn.Dropout(0.1),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, node_dim)\n        )\n        \n        # \u8fb9\u6743\u91cd\u9884\u6d4b\uff1a\u6a21\u62df\u7a7a\u95f4\u8ddd\u79bb\uff08\u6982\u7387\u8d8a\u9ad8=\u7a7a\u95f4\u8d8a\u8fd1\uff09\n        self.edge_decoder = nn.Sequential(\n            nn.Linear(hidden_dim * 2, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, k_candidates),\n            nn.Sigmoid()  # \u8f93\u51fa&#91;0,1]\u6982\u7387\n        )\n        \n        self.k_candidates = k_candidates\n\n    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, center_idx: int) -> Tuple&#91;\n        torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        \u524d\u5411\u4f20\u64ad\n        :param x: \u8282\u70b9\u7279\u5f81 &#91;N, node_dim]\n        :param edge_index: \u8fb9\u7d22\u5f15 &#91;2, E]\n        :param center_idx: \u4e2d\u5fc3\u8282\u70b9\u539f\u59cb\u7d22\u5f15\n        :return: (\u5408\u6210\u8282\u70b9\u7279\u5f81&#91;1, node_dim], \u8fb9\u6743\u91cd&#91;1, k], \u5019\u9009\u90bb\u5c45\u539f\u59cb\u7d22\u5f15&#91;k])\n        \"\"\"\n        # 1. \u63d0\u53d6\u4e2d\u5fc3\u8282\u70b9\u76841\u9636\u5b50\u56fe\uff08\u7269\u7406\u7a7a\u95f4\u7ea6\u675f\uff09\n        subset, sub_edge, center_map, _ = k_hop_subgraph(\n            node_idx=center_idx,\n            num_hops=1,\n            edge_index=edge_index,\n            num_nodes=x.size(0),\n            relabel_nodes=True\n        )\n        \n        # 2. \u5982\u679c\u5b50\u56fe\u592a\u5c0f\uff0c\u7528\u4e2d\u5fc3\u8282\u70b9\u586b\u5145\u9632\u6b62\u5d29\u6e83\n        if subset.size(0) &lt; 2:\n            subset = torch.tensor(&#91;center_idx, center_idx], device=x.device)\n            sub_edge = torch.tensor(&#91;&#91;0, 1], &#91;1, 0]], device=x.device, dtype=torch.long)\n            center_map = 0\n        \n        # 3. \u56fe\u5377\u79ef\u7f16\u7801\u7ed3\u6784\u4fe1\u606f\n        sub_x = x&#91;subset]\n        h = F.relu(self.conv1(sub_x, sub_edge))\n        h = F.relu(self.conv2(h, sub_edge))\n        \n        # 4. \u4e2d\u5fc3\u8282\u70b9\u7279\u5f81 + \u5b50\u56fe\u5168\u5c40 pooling\n        h_center = h&#91;center_map].squeeze(0)  # &#91;hidden]\n        h_pool = global_mean_pool(h, torch.zeros(h.size(0), dtype=torch.long, device=h.device)).squeeze(0)  # &#91;hidden]\n        \n        # 5. \u89e3\u7801\u751f\u6210\n        h_cat = torch.cat(&#91;h_center.unsqueeze(0), h_pool.unsqueeze(0)], dim=1)  # &#91;1, hidden*2]\n        new_x = self.feat_decoder(h_cat)  # &#91;1, node_dim]\n\n        # ===== \u5f3a\u5236\u591a\u6837\u5316\uff1a\u9ad8\u65af\u566a\u58f0 + Dropout =====\n        new_x = new_x + torch.randn_like(new_x, device=new_x.device) * 0.1  # \u9ad8\u65af\u566a\u58f0\n        new_x = F.dropout(new_x, p=0.2, training=self.training)            # dropout\n\n        edge_probs = self.edge_decoder(h_cat)  # &#91;1, k_candidates]\n        \n        # 6. \u63d0\u53d6\u5019\u9009\u90bb\u5c45\u539f\u59cb\u7d22\u5f15\uff08\u6392\u9664\u4e2d\u5fc3\u8282\u70b9\uff09\n        candidate_neighbors = subset&#91;subset != center_idx]\n        \n        # \u5982\u679c\u4e0d\u8db3k\u4e2a\uff0c\u91cd\u590d\u586b\u5145\n        if len(candidate_neighbors) &lt; self.k_candidates:\n            repeats = (self.k_candidates \/\/ len(candidate_neighbors)) + 1\n            candidate_neighbors = candidate_neighbors.repeat(repeats)&#91;:self.k_candidates]\n        \n        return new_x, edge_probs, candidate_neighbors&#91;:self.k_candidates]\n\n\nclass GeoDiscriminator(nn.Module):\n    \"\"\"\n    \u5224\u522b\u5668\uff1a\u5224\u65ad\u5b50\u56fe-\u8282\u70b9\u5bf9\u662f\u5426\u6765\u81ea\u771f\u5b9e\u6570\u636e\n    \u8f93\u5165\uff1a\u5b50\u56fe\u7279\u5f81 + \u662f\u5426\u4e3a\u751f\u6210\u8282\u70b9\u7684\u6807\u8bb0\n    \"\"\"\n    def __init__(self, node_dim: int, hidden_dim: int = HIDDEN_DIM):\n        super().__init__()\n        # \u8f93\u5165=\u7279\u5f81 + \u751f\u6210\u6807\u8bb0(1\u7ef4)\n        self.conv = GCNConv(node_dim + 1, hidden_dim)\n        self.mlp = nn.Sequential(\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Dropout(0.1),\n            nn.Linear(hidden_dim, 1),\n            nn.Sigmoid()\n        )\n\n    def forward(self, sub_x: torch.Tensor, sub_edge: torch.Tensor, is_gen_flag: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        :param sub_x: \u5b50\u56fe\u8282\u70b9\u7279\u5f81 &#91;N_sub, node_dim]\n        :param sub_edge: \u5b50\u56fe\u8fb9\u7d22\u5f15 &#91;2, E_sub]\n        :param is_gen_flag: \u751f\u6210\u6807\u8bb0 &#91;N_sub, 1], 1=\u751f\u6210\u8282\u70b9\/\u5c11\u6570\u7c7b\u4e2d\u5fc3, 0=\u771f\u5b9e\u666e\u901a\u8282\u70b9\n        :return: \u5b50\u56fe\u4e3a\u771f\u7684\u6982\u7387 &#91;1]\n        \"\"\"\n        h = torch.cat(&#91;sub_x, is_gen_flag], dim=1)\n        h = F.relu(self.conv(h, sub_edge))\n        graph_rep = global_mean_pool(h, torch.zeros(h.size(0), dtype=torch.long, device=h.device))\n        return self.mlp(graph_rep)\n\n\n# ==================== 2. \u8bad\u7ec3\u51fd\u6570 ====================\n\ndef train_generator_on_graph(graph: Data, generator: GeoGraphGenerator, \n                             discriminator: GeoDiscriminator, device: str) -> GeoGraphGenerator:\n    \"\"\"\n    \u5728\u5355\u4e2a\u86cb\u767d\u8d28\u56fe\u5185\u90e8\u8bad\u7ec3\u751f\u6210\u5668\u548c\u5224\u522b\u5668\n    :param graph: \u5355\u4e2a\u86cb\u767d\u8d28\u56fe\n    :param generator: \u751f\u6210\u5668\u5b9e\u4f8b\n    :param discriminator: \u5224\u522b\u5668\u5b9e\u4f8b\n    :param device: \u8ba1\u7b97\u8bbe\u5907\n    :return: \u8bad\u7ec3\u540e\u7684\u751f\u6210\u5668\n    \"\"\"\n    generator.train()\n    discriminator.train()\n    \n    optG = torch.optim.Adam(generator.parameters(), lr=LR_GAN)\n    optD = torch.optim.Adam(discriminator.parameters(), lr=LR_GAN)\n    bce_loss = nn.BCELoss()\n    \n    # \u83b7\u53d6\u6b63\u7c7b\u8282\u70b9\u7d22\u5f15\n    pos_idx = (graph.y == 1).nonzero(as_tuple=True)&#91;0]\n    if len(pos_idx) == 0:\n        return generator  # \u6ca1\u6709\u6b63\u7c7b\uff0c\u65e0\u6cd5\u8bad\u7ec3\n    \n    # \u6bcf\u4e2aepoch\u968f\u673a\u91c7\u6837\u591a\u4e2a\u4e2d\u5fc3\u8282\u70b9\uff08\u589e\u52a0\u591a\u6837\u6027\uff09\n    for epoch in range(NUM_EPOCHS_PER_GRAPH):\n        epoch_loss_d, epoch_loss_g = 0.0, 0.0\n        \n        # \u6bcf\u4e2aepoch\u8bad\u7ec3batch_size\u4e2a\u4e2d\u5fc3\u8282\u70b9\n        batch_size = min(8, len(pos_idx))\n        centers = pos_idx&#91;torch.randperm(len(pos_idx))&#91;:batch_size]]\n        \n        for center_idx in centers:\n            # ---------------- \u5224\u522b\u5668\u8bad\u7ec3 ----------------\n            optD.zero_grad()\n            \n            # \u771f\u5b9e\u5b50\u56fe\n            # real_subset, real_edge, real_center_map, _ = k_hop_subgraph(\n            #     center_idx.item(), 1, graph.edge_index, graph.x.size(0), relabel_nodes=True\n            # )\n            # \u771f\u5b9e\u5b50\u56fe\n            real_subset, real_edge, real_center_map, _ = k_hop_subgraph(\n            node_idx=center_idx.item(),\n            num_hops=1,\n            edge_index=graph.edge_index,\n            relabel_nodes=True,\n            num_nodes=graph.x.size(0)\n            )\n\n\n            real_sub_x = graph.x&#91;real_subset].to(device)\n            real_is_gen = torch.zeros(real_sub_x.size(0), 1, device=device)\n            real_is_gen&#91;real_center_map] = 1.0  # \u4e2d\u5fc3\u8282\u70b9\u6807\u8bb0\u4e3a1\n            \n            # \u5224\u522b\u5668\u6253\u5206\n            D_real = discriminator(real_sub_x, real_edge.to(device), real_is_gen)\n            loss_real = bce_loss(D_real, torch.ones_like(D_real))\n            \n            # \u5047\u5b50\u56fe\uff08\u751f\u6210\u5668\u751f\u6210\uff09\n            with torch.no_grad():\n                fake_x, fake_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx.item())\n                fake_x = fake_x.to(device)\n            \n            # \u7b5b\u9009\u9ad8\u6982\u7387\u90bb\u5c45\n            mask = fake_probs.squeeze() > EDGE_THRESHOLD\n            selected_neighbors = neighbor_idx&#91;mask.cpu()]\n            \n            if len(selected_neighbors) == 0:  # \u6ca1\u6709\u9009\u4e2d\u4efb\u4f55\u90bb\u5c45\uff0c\u8df3\u8fc7\n                continue\n            \n            # \u6784\u5efa\u5047\u5b50\u56fe\n            fake_neighbor_x = graph.x&#91;selected_neighbors].to(device)\n            fake_sub_x = torch.cat(&#91;fake_neighbor_x, fake_x], dim=0)\n            fake_edge = torch.stack(&#91;\n                torch.arange(len(selected_neighbors), dtype=torch.long, device=device),\n                torch.full((len(selected_neighbors),), len(selected_neighbors), dtype=torch.long, device=device)\n            ], dim=0)\n            fake_is_gen = torch.zeros(fake_sub_x.size(0), 1, device=device)\n            fake_is_gen&#91;-1] = 1.0  # \u751f\u6210\u8282\u70b9\u6807\u8bb0\u4e3a1\n            \n            # \u5224\u522b\u5668\u5bf9\u5047\u5b50\u56fe\u6253\u5206\n            D_fake = discriminator(fake_sub_x, fake_edge, fake_is_gen)\n            loss_fake = bce_loss(D_fake, torch.zeros_like(D_fake))\n            \n            # \u5224\u522b\u5668\u603b\u635f\u5931\n            loss_d = loss_real + loss_fake\n            loss_d.backward()\n            optD.step()\n            epoch_loss_d += loss_d.item()\n            \n            # ---------------- \u751f\u6210\u5668\u8bad\u7ec3 ----------------\n            optG.zero_grad()\n            \n            # \u91cd\u65b0\u751f\u6210\uff08\u6b64\u65f6\u9700\u8981\u68af\u5ea6\uff09\n            fake_x, fake_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx.item())\n            fake_x = fake_x.to(device)\n            \n            # \u91cd\u65b0\u7b5b\u9009\u90bb\u5c45\n            mask = fake_probs.squeeze() > EDGE_THRESHOLD\n            selected_neighbors = neighbor_idx&#91;mask.cpu()]\n            \n            if len(selected_neighbors) == 0:\n                continue\n            \n            # \u91cd\u65b0\u6784\u5efa\u5047\u5b50\u56fe\n            fake_neighbor_x = graph.x&#91;selected_neighbors].to(device)\n            fake_sub_x = torch.cat(&#91;fake_neighbor_x, fake_x], dim=0)\n            fake_edge = torch.stack(&#91;\n                torch.arange(len(selected_neighbors), dtype=torch.long, device=device),\n                torch.full((len(selected_neighbors),), len(selected_neighbors), dtype=torch.long, device=device)\n            ], dim=0)\n            fake_is_gen = torch.zeros(fake_sub_x.size(0), 1, device=device)\n            fake_is_gen&#91;-1] = 1.0\n            \n            # \u751f\u6210\u5668\u635f\u5931\uff1a\u5e0c\u671b\u5224\u522b\u5668\u5bf9\u5047\u5b50\u56fe\u6253\u9ad8\u5206\n            D_fake = discriminator(fake_sub_x, fake_edge, fake_is_gen)\n            loss_g = bce_loss(D_fake, torch.ones_like(D_fake))\n            \n            loss_g.backward()\n            optG.step()\n            epoch_loss_g += loss_g.item()\n        \n        # \u6bcf10\u4e2aepoch\u6253\u5370\u4e00\u6b21\n        if (epoch + 1) % 10 == 0:\n            print(f\"    Epoch {epoch+1}\/{NUM_EPOCHS_PER_GRAPH}, \"\n                  f\"Loss D: {epoch_loss_d\/batch_size:.4f}, Loss G: {epoch_loss_g\/batch_size:.4f}\")\n    \n    return generator\n\n# ==================== 3. \u5e73\u8861\u51fd\u6570 ====================\n\ndef balance_single_graph(graph: Data, generator: GeoGraphGenerator = None, \n                         device: str = DEVICE, k_candidates: int = K_CANDIDATES) -> Data:\n    \"\"\"\n    \u5bf9\u5355\u4e2a\u86cb\u767d\u8d28\u56fe\u8fdb\u884c1:1\u5e73\u8861\n    :param graph: \u539f\u59cb\u56fe\n    :param generator: \u9884\u8bad\u7ec3\u751f\u6210\u5668\uff08\u5982\u679c\u4e3aNone\u5219\u8bad\u7ec3\u65b0\u7684\uff09\n    :param device: \u8ba1\u7b97\u8bbe\u5907\n    :param k_candidates: \u5019\u9009\u90bb\u5c45\u6570\u91cf\n    :return: \u5e73\u8861\u540e\u7684\u56fe\n    \"\"\"\n    # \u786e\u4fddedge_index\u662f\u65e0\u5411\u7684\n    if not is_undirected(graph.edge_index):\n        graph.edge_index = to_undirected(graph.edge_index)\n        print(f\"\u8b66\u544a: \u8f93\u5165\u56fe\u4e0d\u662f\u65e0\u5411\u56fe\uff0c\u5df2\u81ea\u52a8\u8f6c\u6362\")\n    \n    graph = graph.to(device)\n    pos_idx = (graph.y == 1).nonzero(as_tuple=True)&#91;0]\n    neg_idx = (graph.y == 0).nonzero(as_tuple=True)&#91;0]\n    gap = len(neg_idx) - len(pos_idx)\n    \n    # \u5df2\u5e73\u8861\u6216\u8d1f\u7c7b\u66f4\u5c11\n    if gap &lt;= 0:\n        return graph.cpu()\n    \n    # \u5982\u679c\u6ca1\u6709\u63d0\u4f9b\u751f\u6210\u5668\uff0c\u5219\u8bad\u7ec3\u4e00\u4e2a\u65b0\u7684\n    if generator is None:\n        print(f\"  \u8bad\u7ec3\u65b0\u7684\u751f\u6210\u5668...\")\n        generator = GeoGraphGenerator(node_dim=graph.x.size(1), k_candidates=k_candidates).to(device)\n        discriminator = GeoDiscriminator(node_dim=graph.x.size(1)).to(device)\n        generator = train_generator_on_graph(graph, generator, discriminator, device)\n    else:\n        generator.eval()\n    \n    # \u751f\u6210\u8282\u70b9\n    new_x_list, new_edge_list = &#91;], &#91;]\n    current_num_nodes = graph.x.size(0)\n    \n    print(f\"  \u751f\u6210 {gap} \u4e2a\u8282\u70b9...\")\n    with torch.no_grad():\n        for i in range(gap):\n            # \u968f\u673a\u9009\u62e9\u6a21\u677f\u4e2d\u5fc3\n            center_idx = pos_idx&#91;torch.randint(len(pos_idx), (1,))].item()\n            \n            # \u751f\u6210\u8282\u70b9\u7279\u5f81\u548c\u8fb9\u6743\u91cd\n            fake_x, edge_probs, neighbor_idx = generator(graph.x, graph.edge_index, center_idx)\n            \n            # \u7b5b\u9009\u9ad8\u6982\u7387\u90bb\u5c45\uff08\u6a21\u62df\u7a7a\u95f4\u8ddd\u79bb\u7b5b\u9009\uff09\n            mask = edge_probs.squeeze().cpu() > EDGE_THRESHOLD\n            selected_neighbors = neighbor_idx&#91;mask.cpu()]\n            \n            if len(selected_neighbors) == 0:\n                # \u5982\u679c\u90fd\u6ca1\u9009\u4e2d\uff0c\u9ed8\u8ba4\u8fde\u7b2c\u4e00\u4e2a\u5019\u9009\u90bb\u5c45\n                selected_neighbors = neighbor_idx&#91;:1]\n            \n            # \u4fdd\u5b58\u751f\u6210\u8282\u70b9\u7279\u5f81\n            new_x_list.append(fake_x.cpu())\n            \n            # \u6784\u5efa\u65b0\u8fb9\uff08\u65e0\u5411\u56fe\u9700\u8981\u53cc\u5411\u8fde\u63a5\uff09\n            new_node_id = current_num_nodes + i\n            # \u6dfb\u52a0\u6b63\u5411\u8fb9\uff1a\u5019\u9009\u90bb\u5c45 \u2192 \u65b0\u8282\u70b9\n            new_edges_forward = torch.stack(&#91;\n                selected_neighbors,\n                torch.full_like(selected_neighbors, new_node_id, dtype=torch.long)\n            ], dim=0)\n            # \u6dfb\u52a0\u53cd\u5411\u8fb9\uff1a\u65b0\u8282\u70b9 \u2192 \u5019\u9009\u90bb\u5c45\n            new_edges_backward = torch.stack(&#91;\n                torch.full_like(selected_neighbors, new_node_id, dtype=torch.long),\n                selected_neighbors\n            ], dim=0)\n            # \u5408\u5e76\n            new_edges = torch.cat(&#91;new_edges_forward, new_edges_backward], dim=1)\n            new_edge_list.append(new_edges)\n    \n    # \u5408\u5e76\u6240\u6709\u65b0\u8282\u70b9\u548c\u8fb9\n    all_new_x = torch.cat(new_x_list, dim=0)\n    new_x = torch.cat(&#91;graph.x.cpu(), all_new_x], dim=0)\n    new_y = torch.cat(&#91;graph.y.cpu(), torch.ones(gap, dtype=torch.long)], dim=0)\n    \n    # \u5408\u5e76\u8fb9\n    if new_edge_list:\n        all_new_edges = torch.cat(new_edge_list, dim=1)\n        new_edge_index = torch.cat(&#91;graph.edge_index.cpu(), all_new_edges.cpu()], dim=1)\n        # \u53bb\u91cd\u548c\u53bb\u81ea\u73af\n        new_edge_index = to_undirected(new_edge_index)\n    else:\n        new_edge_index = graph.edge_index.cpu()\n    \n    return Data(x=new_x, edge_index=new_edge_index, y=new_y)\n\n\ndef balance_graph_list(graph_list: List&#91;Data], device: str = DEVICE, \n                       k_candidates: int = K_CANDIDATES) -> List&#91;Data]:\n    \"\"\"\n    \u6279\u91cf\u5904\u7406\u591a\u4e2a\u86cb\u767d\u8d28\u56fe\n    :param graph_list: \u539f\u59cb\u56fe\u5217\u8868\n    :param device: \u8ba1\u7b97\u8bbe\u5907\n    :param k_candidates: \u5019\u9009\u90bb\u5c45\u6570\u91cf\n    :return: \u5e73\u8861\u540e\u7684\u56fe\u5217\u8868\n    \"\"\"\n    balanced_list = &#91;]\n    print(f\"\u5f00\u59cb\u5904\u7406 {len(graph_list)} \u4e2a\u86cb\u767d\u8d28\u56fe...\")\n    \n    for idx, graph in enumerate(graph_list):\n        print(f\"\\n&#91;{idx+1}\/{len(graph_list)}] \u86cb\u767d\u56fe (\u539f\u59cb\u8282\u70b9: {graph.x.size(0)}, \u6b63\u7c7b: {graph.y.sum().item()})\")\n        balanced_graph = balance_single_graph(graph, None, device, k_candidates)\n        \n        # \u6253\u5370\u5e73\u8861\u4fe1\u606f\n        ratio_before = graph.y.float().mean().item()\n        ratio_after = balanced_graph.y.float().mean().item()\n        print(f\"  \u2514\u2500> \u5e73\u8861\u540e: \u8282\u70b9={balanced_graph.x.size(0)}, \u6b63\u7c7b={balanced_graph.y.sum().item()}, \"\n              f\"\u6b63\u7c7b\u5360\u6bd4: {ratio_before:.2f} -> {ratio_after:.2f}\")\n        \n        balanced_list.append(balanced_graph)\n        # \u5047\u8bbe\u5217\u8868\u4e2d\u53ea\u6709\u4e00\u4e2a\u56fe\uff0c\u5e73\u8861\u4e00\u4e2a\u56fe\u770b\u770b\u60c5\u51b5\u3002\n        break\n    \n    return balanced_list\n\n# ==================== 5. \u4e3b\u6267\u884c\u6d41\u7a0b ====================\n\ndef main():\n    # \u6570\u636e\u8def\u5f84\n    DATA_PATH = \"ESM\/graph_data_573.pt\"\n    OUTPUT_PATH = \"GAN\/graph_data_573_balanced.pt\"\n    \n    print(\"=\"*60)\n    print(\"Geo-ImGAGN \u86cb\u767d\u8d28\u56fe\u975e\u5e73\u8861\u5904\u7406\")\n    print(\"=\"*60)\n    print(f\"\u8d85\u53c2\u914d\u7f6e: \u5019\u9009\u90bb\u5c45={K_CANDIDATES}, \u8bad\u7ec3\u8f6e\u6570={NUM_EPOCHS_PER_GRAPH}, \u8bbe\u5907={DEVICE}\")\n    \n    # 1. \u52a0\u8f7d\u6570\u636e\n    print(\"\\n&#91;1\/4] \u52a0\u8f7d\u86cb\u767d\u8d28\u56fe\u6570\u636e...\")\n    try:\n        graph_list = torch.load(DATA_PATH, weights_only=False)\n        print(f\"  \u6210\u529f\u52a0\u8f7d {len(graph_list)} \u4e2a\u86cb\u767d\u8d28\u56fe\")\n    except Exception as e:\n        print(f\"  \u9519\u8bef: \u65e0\u6cd5\u52a0\u8f7d\u6570\u636e {e}\")\n        return\n    \n    # 2. \u6279\u91cf\u5e73\u8861\u5904\u7406\n    print(\"\\n&#91;2\/4] \u5f00\u59cb\u5e73\u8861\u5904\u7406...\")\n    balanced_graphs = balance_graph_list(graph_list, device=DEVICE, k_candidates=K_CANDIDATES)\n    \n    # 3. \u4fdd\u5b58\u7ed3\u679c\n    print(\"\\n&#91;3\/4] \u4fdd\u5b58\u5e73\u8861\u540e\u7684\u6570\u636e...\")\n    torch.save(balanced_graphs, OUTPUT_PATH)\n    print(f\"  \u5df2\u4fdd\u5b58\u81f3 {OUTPUT_PATH}\")\n    \n    # 4. \u9a8c\u8bc1\u6307\u6807\n    print(\"\\n&#91;4\/4] \u8ba1\u7b97\u9a8c\u8bc1\u6307\u6807...\")\n    metrics = evaluate.evaluate_generation_quality(graph_list, balanced_graphs, device=DEVICE)\n    \n   # ==================== \u6700\u7ec8\u603b\u7ed3 ====================\n    print(\"\\n\" + \"=\"*60)\n    print(\"\u5904\u7406\u5b8c\u6210\uff01\u6700\u7ec8\u603b\u7ed3\")\n    print(\"=\"*60)\n    \n    # \u57fa\u672c\u4fe1\u606f\n    total_new_nodes = metrics&#91;'summary']&#91;'total_generated_nodes']\n    print(f\"  \u65b0\u589e\u5408\u6210\u8282\u70b9: {total_new_nodes:,} \u4e2a\")\n    \n    # \u4e00\u7ea7\u6307\u6807\uff1a\u7a7a\u95f4\u5408\u7406\u6027\n    coverage_2hop = metrics&#91;'pocket_coverage_graph']&#91;'2hop_coverage']\n    print(f\"  \u56fe\u8df3\u65702\u8df3\u8986\u76d6\u7387: {coverage_2hop:.2%}  {'\u2705 \u8fbe\u6807' if coverage_2hop > 0.6 else '\u274c \u672a\u8fbe\u6807'}\")\n    \n    # \u4e8c\u7ea7\u6307\u6807\uff1a\u62d3\u6251\u4e00\u81f4\u6027\n    jaccard = metrics&#91;'jaccard_similarity']&#91;'mean_score']\n    degree_ratio = metrics&#91;'degree_ratio']&#91;'ratio']\n    print(f\"  Jaccard\u76f8\u4f3c\u5ea6: {jaccard:.4f}  {'\u2705 \u8fbe\u6807' if jaccard > 0.4 else '\u274c \u672a\u8fbe\u6807'}\")\n    print(f\"  \u5ea6\u6570\u6bd4\u7387: {degree_ratio:.2f}  {'\u2705 \u8fbe\u6807' if 0.5 &lt;= degree_ratio &lt;= 1.5 else '\u274c \u672a\u8fbe\u6807'}\")\n    \n    # \u4e09\u7ea7\u6307\u6807\uff1a\u7279\u5f81\u4e0e\u591a\u6837\u6027\n    kl_div = metrics.get('kl_divergence', {}).get('value', 999.0)\n    cos_sim = metrics&#91;'cosine_similarity']&#91;'mean']\n    print(f\"  KL\u6563\u5ea6\/\u7ef4\u5ea6: {kl_div:.4f}  {'\u2705 \u826f\u597d' if kl_div &lt; 0.2 else '\u26a0\ufe0f \u4e00\u822c'}\")\n    print(f\"  \u4f59\u5f26\u76f8\u4f3c\u5ea6: {cos_sim:.4f}  {'\u2705 \u826f\u597d' if cos_sim &lt; 0.85 else '\u26a0\ufe0f \u504f\u9ad8'}\")\n    \n    # \u6574\u4f53\u8bc4\u4f30\u4e0e\u5efa\u8bae\n    print(\"\\n\" + \"-\"*60)\n    print(\"\u6574\u4f53\u8bc4\u4f30\u4e0e\u5efa\u8bae:\")\n    \n    # \u4e00\u7ea7\u6307\u6807\u662f\u5426\u901a\u8fc7\n    if metrics&#91;'pocket_coverage_graph']&#91;'status'] == 'PASS':\n        print(\"  \u2705 \u7a7a\u95f4\u5408\u7406\u6027: \u751f\u6210\u8282\u70b9\u5927\u591a\u843d\u5728\u53e3\u888b2\u8df3\u5185\")\n    else:\n        print(\"  \u274c \u7a7a\u95f4\u5408\u7406\u6027: \u751f\u6210\u8282\u70b9\u504f\u79bb\u53e3\u888b\uff0c\u5efa\u8bae\u964d\u4f4eEDGE_THRESHOLD\u81f30.3\")\n    \n    # \u4e8c\u7ea7\u6307\u6807\u662f\u5426\u901a\u8fc7\n    topology_ok = (metrics&#91;'jaccard_similarity']&#91;'status'] == 'PASS' and \n                   metrics&#91;'degree_ratio']&#91;'status'] == 'PASS')\n    if topology_ok:\n        print(\"  \u2705 \u62d3\u6251\u4e00\u81f4\u6027: \u8fde\u63a5\u6a21\u5f0f\u4e0e\u771f\u5b9e\u6b63\u7c7b\u5339\u914d\")\n    else:\n        print(\"  \u26a0\ufe0f \u62d3\u6251\u4e00\u81f4\u6027: \u5efa\u8bae\u589e\u52a0NUM_EPOCHS_PER_GRAPH\u81f3100\u6216\u8c03\u6574K_CANDIDATES\")\n    \n    # \u4e09\u7ea7\u6307\u6807\n    diversity_ok = (metrics.get('kl_divergence', {}).get('status', 'N\/A') != 'WARN' and \n                    metrics&#91;'cosine_similarity']&#91;'status'] != 'WARN')\n    if diversity_ok:\n        print(\"  \u2705 \u591a\u6837\u6027\u4e0e\u7279\u5f81: \u6a21\u5f0f\u4e30\u5bcc\uff0c\u65e0\u5d29\u6e83\u98ce\u9669\")\n    else:\n        print(\"  \u26a0\ufe0f \u591a\u6837\u6027\u4e0e\u7279\u5f81: \u5efa\u8bae\u76d1\u63a7G_loss\uff0c\u9632\u6b62\u6a21\u5f0f\u5d29\u6e83\")\n    \n    # \u6700\u7ec8\u5efa\u8bae\n    print(\"\\n\" + \"\u6700\u7ec8\u5efa\u8bae:\")\n    if metrics&#91;'pocket_coverage_graph']&#91;'status'] == 'PASS' and topology_ok:\n        print(\"  \ud83c\udf89 \u751f\u6210\u8282\u70b9\u8d28\u91cf\u8fbe\u6807\uff01\u53ef\u76f4\u63a5\u7528\u4e8e\u4e0b\u6e38GCN\u8bad\u7ec3\")\n        print(\"  \ud83d\udccc \u5efa\u8bae\u5728\u9a8c\u8bc1\u96c6\u76d1\u63a7AUPRC\uff0c\u82e5\u4e0b\u964d\u56de\u8c03EDGE_THRESHOLD\")\n    else:\n        print(\"  \ud83d\udd27 \u6838\u5fc3\u6307\u6807\u672a\u8fbe\u6807\uff0c\u5efa\u8bae:\")\n        print(\"     1. \u964d\u4f4eEDGE_THRESHOLD\u81f30.3\")\n        print(\"     2. \u589e\u52a0K_CANDIDATES\u81f35\")\n        print(\"     3. \u589e\u52a0NUM_EPOCHS_PER_GRAPH\u81f3100\")\n    \n    print(\"\\n\u4ee3\u7801\u6267\u884c\u5b8c\u6bd5\uff01\")\n    print(\"=\"*60)\n\n\nif __name__ == '__main__':\n    main()<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[25],"tags":[],"class_list":["post-174","post","type-post","status-publish","format-standard","hentry","category-25"],"_links":{"self":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/174","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=174"}],"version-history":[{"count":1,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/174\/revisions"}],"predecessor-version":[{"id":175,"href":"https:\/\/snakesleep.work\/index.php?rest_route=\/wp\/v2\/posts\/174\/revisions\/175"}],"wp:attachment":[{"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=174"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=174"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/snakesleep.work\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=174"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}