728x90
https://www.acmicpc.net/problem/13510

문제 개요
N개의 정점으로 이루어진 트리가 있고, 각 간선에는 비용이 있다. 아래 두 쿼리를 수행하는 프로그램을 작성해야 합니다.
- 1 i c: i번 간선의 비용을 c로 변경
- 2 u v: u에서 v로 가는 경로에 있는 간선 비용 중 최댓값 출력
2 ≤ N ≤ 100,000
1 ≤ M ≤ 100,000 (쿼리 개수)
간선 비용 ≤ 1,000,000
문제 풀이
이 문제는 Heavy-Light Decomposition(HLD)를 사용하여 해결할 수 있었습니다.
Heavy-Light Decomposition이란?
HLD는 트리를 여러 개의 체인으로 분해하여 경로 쿼리를 효율적으로 처리하는 기법입니다.
핵심 아이디어:
- 각 노드에서 서브트리 크기가 가장 큰 자식으로 가는 간선을 Heavy Edge라고 함
- 나머지 간선은 Light Edge
- Heavy Edge로 연결된 노드들을 하나의 체인으로 묶음
- 각 체인을 세그먼트 트리로 관리
왜 효율적인가?
- 임의의 두 정점 사이 경로는 최대 O(logN)개의 체인을 지나감
- 각 체인에서의 쿼리는 세그먼트 트리로 O(logN)에 처리
- 따라서 경로 쿼리는 O(log²N)
구현 과정
1단계: DFS로 Heavy Child 결정
각 노드의 서브트리 크기를 계산하고, 가장 큰 서브트리를 가진 자식을 Heavy로 표시한다.
int dfs(int x, int pre) {
node[x].par = pre;
node[x].heavy = -1;
int sum = 1, mx = -1;
for (auto next : edges[x]) {
if (pre == next.first) continue;
cost[next.first] = next.second;
int sz = dfs(next.first, x);
sum += sz;
if (mx < sz) {
node[x].heavy = next.first;
mx = sz;
}
}
return sum;
}
2단계: HLD 분해
Heavy Edge를 따라 같은 체인으로 묶고, 각 노드에 세그먼트 트리 인덱스를 부여한다.
void hld(int x, int top, int d) {
node[x].pos = cnt++; // 세그먼트 트리 인덱스
node[x].top = top; // 체인의 최상단 노드
node[x].depth = d;
// Heavy child를 먼저 방문 (같은 체인 유지)
if (node[x].heavy != -1) {
hld(node[x].heavy, top, d + 1);
}
// Light children 방문 (새 체인 시작)
for (auto next : edges[x]) {
if (next.first == node[x].par || next.first == node[x].heavy)
continue;
hld(next.first, next.first, d + 1);
}
}
3단계: 경로 쿼리
두 정점이 같은 체인에 올 때까지 체인을 타고 올라가며 최댓값을 구한다.
int path_query(int x, int y) {
int ret = 0;
while (node[x].top != node[y].top) {
// 더 깊은 체인을 올림
if (node[node[x].top].depth < node[node[y].top].depth)
swap(x, y);
// 현재 체인에서 쿼리
ret = max(ret, query(1, 0, n - 1,
node[node[x].top].pos, node[x].pos));
x = node[node[x].top].par;
}
// 같은 체인 내에서 쿼리
if (node[x].depth > node[y].depth) swap(x, y);
if (node[x].pos < node[y].pos) {
ret = max(ret, query(1, 0, n - 1,
node[x].pos + 1, node[y].pos));
}
return ret;
}
4단계: 간선 업데이트
간선의 비용은 두 노드 중 깊이가 더 깊은 노드의 위치에 저장한다.
int get_node(int idx) {
int x = edge_info[idx].first;
int y = edge_info[idx].second;
return node[x].depth > node[y].depth ? x : y;
}
시간 복잡도
- 전처리: O(N)
- 업데이트: O(logN)
- 경로 쿼리: O(log²N)
코드
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int MAX = 100001;
int n, m, cnt;
vector<pair<int, int>> edges[MAX];
pair<int, int> edge_info[MAX];
int cost[MAX];
class Node {
public:
int pos, par, top, depth, heavy;
}node[MAX];
int seg[MAX * 4];
int dfs(int x, int pre) {
node[x].par = pre;
node[x].heavy = -1;
int sum = 1, mx = -1;
for (auto next : edges[x]) {
if (pre == next.first) continue;
cost[next.first] = next.second;
int sz = dfs(next.first, x);
sum += sz;
if (mx < sz) {
node[x].heavy = next.first;
mx = sz;
}
}
return sum;
}
void hld(int x, int top, int d) {
node[x].pos = cnt++;
node[x].top = top;
node[x].depth = d;
if (node[x].heavy != -1) {
hld(node[x].heavy, top, d + 1);
}
for (auto next : edges[x]) {
if (next.first == node[x].par || next.first == node[x].heavy) continue;
hld(next.first, next.first, d + 1);
}
}
int update(int x, int s, int e, int idx, int v) {
if (idx < s || e < idx) return seg[x];
if (s == e) return seg[x] = v;
int mid = s + (e - s) / 2;
return seg[x] = max(update(x * 2, s, mid, idx, v), update(x * 2 + 1, mid + 1, e, idx, v));
}
int query(int x, int s, int e, int l, int r) {
if (r < s || e < l) return 0;
if (l <= s && e <= r) return seg[x];
int mid = s + (e - s) / 2;
return max(query(x * 2, s, mid, l, r), query(x * 2 + 1, mid + 1, e, l, r));
}
int path_query(int x, int y) {
int ret = 0;
while (node[x].top != node[y].top) {
if (node[node[x].top].depth < node[node[y].top].depth) swap(x, y);
ret = max(ret, query(1, 0, n - 1, node[node[x].top].pos, node[x].pos));
x = node[node[x].top].par;
}
if (node[x].depth > node[y].depth) swap(x, y);
if (node[x].pos < node[y].pos) {
ret = max(ret, query(1, 0, n - 1, node[x].pos + 1, node[y].pos));
}
return ret;
}
int get_node(int idx) {
int x = edge_info[idx].first;
int y = edge_info[idx].second;
return node[x].depth > node[y].depth ? x : y;
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n;
int x, y, c;
for (int i = 1; i < n; i++) {
cin >> x >> y >> c;
edges[x].push_back({ y, c });
edges[y].push_back({ x, c });
edge_info[i] = { x, y };
}
dfs(1, 0);
hld(1, 1, 1);
for (int i = 1; i <= n; i++) {
update(1, 0, n - 1, node[i].pos, cost[i]);
}
cin >> m;
for (int i = 0; i < m; i++) {
cin >> c >> x >> y;
if (c == 1) {
x = get_node(x);
update(1, 0, n - 1, node[x].pos, y);
}
else {
int ret = path_query(x, y);
cout << ret << '\n';
}
}
return 0;
}'알고리즘 > 백준' 카테고리의 다른 글
| [BOJ1615] 교차개수세기(c++) (1) | 2026.02.11 |
|---|---|
| [BOJ 14727] 퍼즐 자르기 (c++) (0) | 2026.02.04 |
| [BOJ 20212] 나무는 쿼리를 싫어해~(c++) (1) | 2026.01.13 |
| [BOJ 5419] 북서풍(c++) (0) | 2026.01.13 |
| [BOJ 3002] 아날로그 다이얼(c++) (0) | 2025.12.30 |