Algorithm/Baekjoon

[백준] 21384 : Average Distance

enopid 2025. 4. 7. 12:40

https://www.acmicpc.net/problem/21384

문제 분석

문제 자체는 간단한다. 트리가 있고 엣지별로 길이가 주어질 때 노드 사이의 길이의 평균을 구하는 문제이다. 하지만, 단순히 모든 사이의 길이를 방문하면 총 방문 횟수가 ${}_{10000}C_2 \approx 10^{8}$으로 제한시간내에 문제가 안풀린다. 따라서 모든 노드사이의 길이를 직접 구하지 않고 그길이들의 합만 간접적으로 구할 수있는 트릭이 필요하다.
일단 하나의 노드를 기준으로 해당 노드와 그 자식들로 이루어진 길이만 구한다고 생각하자. 만약, 루트 노드에서 해당 과정을 거치면 전체 트리에서의 길이를 구할 수 있을 것이다. 관건은 해당 노드의 직계 자식 노드을 기준으로 직계 자식 노드와 그 자식들로 이루어진 길이를 바탕으로 해당 노드를 표현하는 것이다. 해당 노드와 그 자식들로 이루어진 길이는 3가지 종류로 나뉜다. 해당 노드를 끝지점으로 가지는 길이, 해당 노드를 지나지않는 길이, 해당 노드를 경유하는 길이 이 3가지 종류는 아래와 같이 부르겠다. 추가적인 연산을 위한 해당 노드를 끝지점으로 가지는 길이의 수도 필요하다.

크게 3가지 종류의 길이의 합과 하가지의 길이의 수 정보가 필요하다.
1. terminalPathSum : 해당 노드가 시작이나 끝 지점인 길이의 합
2. terminalPathNum : 해당 노드가 시작이나 끝 지점인 길이의 수
3. excludedPathSum : 해당 노드를 안 거치는 길이의 합
4. throughPathSum : 해당 노드를 경유하는 길이의 합

이제 특정 노드의 위의 4요소는 자식 노드의 요소들로 표현가능하다.
terminalPathSum 자식 노드의 terminalPathSum에서 자식 노드와 원래 노드사이의 엣지를 추가한케이스와 엣지만 존재하는 케이스로 표현가능하다.
terminalPathNum은 단순히 자식의 terminalPathNum에서 엣지만 있는 케이스 하나만 추가하면 된다.
excludedPathSum은 자식의 terminalPathSum, excludedPathSum, throughPathSum을 모두 더해주면 된다.
throughPathSum은 좀 복잡한데 방금 구한 terminalPathSum을 겹치지 않는 자식 페어마다 계산해주면 된다. 다만, 단순히 모든 페어를 계산하면 $N^2$의 복잡도가 필요해 $ab+bc+ac=(a+b+c)^2-(a^2+b^2+c^2)$처럼 계산 되는 방식을 이용해 복잡도를 줄인다. 

1. $terminalPathSum_i = \sum_{j \in child} (terminalPathSum_j+(terminalPathNum_j+1)*dist_{ij})$
2. $terminalPathNum_i = \sum_{j \in child} (terminalPathNum_j+1)$
3. $excludedPathSum_i = \sum_{j \in child} (terminalPathSum_j+excludedPathSum_j+throughPathSum_j)$
4. $throughPathSum_i = \sum_{j\in child} \sum_{k \in child} (terminalPathSum_j+(terminalPathNum_j+1)*dist_{ij})$$(terminalPathSum_k+(terminalPathNum_k+1)*dist_{ik})$ $(j!=k)$

코드

#include <iostream>
#include <algorithm>
#include <vector>
#include <cassert>
#include <iomanip>

using namespace std;

//문제분석
//특정 노드 기준의 길이는 세 종류로 나뉜다.
//다음 길이들은 해당 노드와 해당 노드의 자식들로만 이루어진 경우이다.(부모는 미포함) 
//1. terminalPathSum    : 해당 노드가 시작이나 끝 지점인 길이의 합
//2. terminalPathNum    : 해당 노드가 시작이나 끝 지점인 길이의 수
//3. excludedPathSum    : 해당 노드를 안 거치는 길이의 합
//4. throughPathSum     : 해당 노드를 경유하는 길이의 합
//
// 특정 노드를 기준으로 위의 4가집값은 자식들로 표현가능하다.
//1. terminalPathSum    =  자식 노드들의 terminalPath에 거리를 더한 값들을 더 해서 구함
//2. terminalPathNum    =  자식 노드를의 (terminalPathNum+1)을 더해서 구함
//3. excludedPathSum    =  자식 노드들의 (terminalPath+excludedPathSum+throughPathSum)을 더 해서 구함
//4. throughPathSum     =  
// terminalPathSum^parent   : \sum (terminalPathSum^child+(terminalPathNum^child+1)*d^child)
// terminalPathNum          : \sum (terminalPathNum^child+1)
// excludedPathSum          : \sum (terminalPathSum^child+excludedPathSum^child+throughPathSum^child)
// throughPathSum           : 
//
// 각 길이의 최대 값    : 1000(n-1)=10^7
// 가능한 길이의 종류   : nC2=10^8
// 모든 길이의 합       : 10^15=(10^3)^5=2^50
//=====================================================================
//해결전략
//
//=====================================================================
//필요자료형
//변수
//
//함수
//

using LL=long long;

struct Node{
public:
    static void ClearNeighbours(const int&N){
        neighbours = vector<vector<pair<int,int>>>(N,vector<pair<int,int>>());
    }
    static void AddNeighbour(const int& a, const int& b, const int& d){
        neighbours[a].push_back({b,d});
        neighbours[b].push_back({a,d});
    }

    Node(const int& curind, const int& parentind){
        for(const auto& child : neighbours[curind]){
            int childind    =child.first;
            int dist        =child.second;
            if (childind==parentind) continue;

            Node childNode(childind, curind);
            terminalPathSum +=childNode.terminalPathSum+dist*(childNode.terminalPathNum+1);
            terminalPathNum +=(childNode.terminalPathNum+1);
            excludedPathSum +=childNode.terminalPathSum+childNode.excludedPathSum+childNode.throughPathSum;
            throughPathSum  +=(childNode.terminalPathNum+1)*(childNode.terminalPathSum+dist*(childNode.terminalPathNum+1));
        }
        throughPathSum=terminalPathNum*terminalPathSum-throughPathSum;
    }

    double GetAverageDistance(){
        double edgeNum = (neighbours.size())*(neighbours.size()-1)/2.0;
        return (terminalPathSum+excludedPathSum+throughPathSum)/edgeNum;
    }
private:
    static vector<vector<pair<int,int>>> neighbours;
    LL  terminalPathSum=0;
    LL  excludedPathSum=0;
    LL  throughPathSum=0;
    int terminalPathNum=0;
};
vector<vector<pair<int,int>>> Node::neighbours;

int main()
{
    int T;
    cin >> T;
    vector<double> answers;
    for(int t=0; t<T; t++)
    {
        int N;
        cin >> N;
        Node::ClearNeighbours(N);
        for(int i=0; i<N-1; i++){
            int a, b, d;
            cin >> a >> b >> d;
            Node::AddNeighbour(a,b,d);
        }
        Node root(0,0);
        answers.push_back(root.GetAverageDistance());
    }
    cout << fixed << setprecision(7);
    for(auto ans: answers) cout << ans << endl;
}

'Algorithm > Baekjoon' 카테고리의 다른 글

[백준] 5855 : Square Overlap  (0) 2025.04.13
[백준] 10901 : Make superpalindrome!  (0) 2025.04.12
[백준] 23250 : 하노이K  (0) 2025.04.05
[백준] 29761 : 물 뿌리기  (0) 2025.04.04
[백준] 29820 : Love Letter  (0) 2025.03.31