본문 바로가기
알고리즘/백준

[BOJ 1517] 버블소트(c++)

by umdoyun 2025. 11. 13.
728x90

https://www.acmicpc.net/problem/1517

문제 개요

 N개의 수로 이루어진 수열에 대해 버블 소트를 수행할 때, swap이 총 몇 번 발생하는지 구하는 문제입니다.

 

문제 풀이

 버블 소트는 서로 인접한 두 수를 비교하여 큰 수를 뒤로 보내는 정렬 방법입니다. 모든 인접한 쌍에 대해 반복적으로 비교하여 swap합니다.

 초기에는 세그먼트 트리로 구간의 최댓값과 최솟값을 비교하여 재귀적으로 각 요소들의 뒤에 요소보다 작은 수의 개수를 직접세어 구현을 하였는데, 시간초과가 났습니다. 아마 대부분의 경우에 세그먼트 트리의 조기 종료 조건에 부합하지 않아 최악의 경우 O(n^2)이 되었던 것 같습니다.그래서 값과 기존 인덱스를 저장하여 정렬 후 변화량을 기록하여 기존 위치로부터의 swap의 수를 구하고 이동한 것을 기록하는 방법으로 바꾸어 처리를 했습니다.

 다른 풀이법으로는 병합 정렬을 구현하여 O(n log n)의 시간으로 버블 소트의 스왑 수를 계산하였습니다. 버블 정렬을 하면 각각의 분할된 리스트를 병합할때 이미 리스트는 정렬되어 있기 때문에 선형적으로 비교하며 정렬할 수 있습니다. 이때 리스트가 합쳐지는 과정에서의 스왑 수를 구하면 버블소트와 동일한 스왑 수를 구할 수 있었습니다.

 

코드

세그먼트 트리

#include <iostream>
#include <algorithm>
using namespace std;

int n;
pair<int, int> arr[500001];
int seg[500001 * 4];

void update(int x, int s, int e, int idx, int val) {
	if (idx < s || idx > e) return;
	if (s == e) {
		seg[x] = val;
		return;
	}
	int mid = s + (e - s) / 2;
	update(x * 2, s, mid, idx, val);
	update(x * 2 + 1, mid + 1, e, idx, val);
	seg[x] = seg[x * 2] + seg[x * 2 + 1];
}

int query(int x, int s, int e, int l, int r) {
	if (r < s || e < l) return 0;
	if (l <= s && e <= r) return seg[x];
	int mid = s + (e - s) / 2;
	return query(x * 2, s, mid, l, r) + query(x * 2 + 1, mid + 1, e, l, r);
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);

	cin >> n;
	for (int i = 0; i < n; i++) {
		cin >> arr[i].first;
		arr[i].second = i;
	}

	sort(arr, arr + n);

	long long res = 0;
	for (int i = 0; i < n; i++) {
		int original_idx = arr[i].second;
		int cnt = query(1, 0, n - 1, original_idx + 1, n - 1);
		res += cnt;
		update(1, 0, n - 1, original_idx, 1);
	}

	cout << res;
	return 0;
}

 

병합 정렬

#include <iostream>
#include <algorithm>
using namespace std;

int n;
int arr[500001];
int tmp[500001];

long long res;

void merge_sort(int s, int e) {
	if (s >= e) return;
	int mid = (s + e) / 2;
	merge_sort(s, mid);
	merge_sort(mid + 1, e);

	int i = s, j = mid + 1, k = s;
	while (i <= mid && j <= e) {
		if (arr[i] <= arr[j]) {
			tmp[k++] = arr[i++];
		}
		else {
			tmp[k++] = arr[j++];
			res += mid - i + 1;
		}
	}
	while (i <= mid) tmp[k++] = arr[i++];
	while (j <= e) tmp[k++] = arr[j++];

	for (int i = s; i <= e; i++) {
		arr[i] = tmp[i];
	}
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n;
	for (int i = 0; i < n; i++) {
		cin >> arr[i];
	}
	merge_sort(0, n - 1);
	cout << res;
	return 0;
}