JOISC/2019

JOISC19 Mergers

arnold518 2020. 10. 2. 23:03

문제

oj.uz/problem/view/JOI19_mergers

 

문제 보기 - Mergers (JOI19_mergers) :: oj.uz

문제 보기 - Mergers (JOI19_mergers)

oj.uz

트리 상의 각 노드는 특정한 색으로 칠해져 있을 때, 트리가 "분할 가능하다"는 것은 트리를 연결된 두 개의 컴포넌트로 쪼개서, 모든 색이 정확히 하나의 컴포넌트에만 속하게 할 수 있다는 것이다. 서로 다른 두 색을 합칠 수 있을 때, 합치는 연산의 횟수를 최소로 해서 트리가 분할 가능하지 않도록 해야 한다.

$N<=500000$

풀이

트리가 분할 가능하기 위해서는, 트리를 두 컴포넌트로 쪼개야 하므로, 그 사이에 어떠한 간선이 존재해서 양쪽의 색을 완벽히 분할할 수 있어야 한다. 그렇다면, 트리에서 같은 색인 노드만 뽑아서 트리압축을 한 후, 압축된 트리에 해당하는 간선에 색칠한다면, 모든 색에 대해서 색칠을 완료한 후, 분할 가능한 간선은 한번도 색칠되지 않아야 한다.

 

분할 가능한 간선만 남기고 남은 간선은 의미가 없으니, 트리를 다시 한번 압축해보자. 이제 모든 간선이 분할 가능한 간선이 트리가 만들어지고, 서로 다른 두 색을 합치는 과정은 그 색의 서로 다른 두 정점 사이의 모든 분할 가능한 간선들에 대해 색칠을 하는 연산이라 생각할 수 있다.

 

즉, 문제를 다음과 같이 변형할 수 있다.

주어진 트리에서 한번의 연산에 경로를 하나 골라 그 경로의 간선들에 색칠할 수 있을 때, 전체 간선들에 모두 색칠하기 위한 최소 횟수를 구해야 한다.

 

리프 노드에 연결된 간선의 경우에는 어떻게 해도 한번의 연산에 최대 두개씩만 삭제할 수 있다. 따라서 답의 하한은 $ceil((leaf의 수)/2)$이라는 것을 알 수 있다.

Observation 1 : 연산의 횟수가 정확히 $ceil((leaf의 수)/2)$인 최적해가 존재한다.

올림 연산이니, 편의를 위해 리프의 수가 짝수라 가정하자.

리프를 DFS ordering 하고, (1, 1+N/2), (2, 2+N/2), ... 와 같이 묶는다. 만약 이와 같은 페어링에서도 색칠되지 않는 간선이 존재한다 가정하자. 색칠되지 않은 간선의 아래쪽 노드의 서브트리에 해당하는 리프들은 하나의 구간을 이룬다. 어떠한 구간을 잡아 모든 (1, 1+N/2), (2, 2+N/2), ... 쌍이 구간 안에나 밖에만 존재해야 한다.

만약 구간이 [1, N/2], [N/2+1, N]범위에 완전히 포함된다면 무조건 밖으로 나가는 페어가 존재하고, 이가 아니라면 N/2, N/2+1가 구간에 포함되며, 조건을 만족시키기 위해서는 1, N 또한 구간에 포함되어야 해서, 결국 간선의 구간은 트리 전체가 된다. 하지만 루트의 부모는 없으므로, 그러한 간선은 존재하지 않음으로 모순이다.

 

각 색에 대해서 압축된 트리의 간선들을 색칠하는 과정은 단순히 전체 lca를 구하고 트리 dp를 통해 해결할 수 있다.

 

시간 복잡도 : $O(NlogN)$

 

#include <bits/stdc++.h>
using namespace std;
 
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
 
const int MAXN = 5e5;
 
int N, K;
vector<int> adj[MAXN+10];
vector<int> V[MAXN+10];
 
int L[MAXN+10], R[MAXN+10], cnt, dep[MAXN+10];
int par[MAXN+10][30], A[MAXN+10];
int dp[MAXN+10], P[MAXN+10], deg[MAXN+10];
 
void dfs(int now, int bef, int d)
{
	L[now]=++cnt;
	par[now][0]=bef;
	dep[now]=d;
	for(int nxt : adj[now])
	{
		if(nxt==bef) continue;
		dfs(nxt, now, d+1);
	}
	R[now]=cnt;
}
 
int lca(int u, int v)
{
	if(dep[u]>dep[v]) swap(u, v);
	for(int i=20; i>=0; i--) if(dep[par[v][i]]>=dep[u]) v=par[v][i];
	if(u==v) return u;
	for(int i=20; i>=0; i--) if(par[u][i]!=par[v][i]) u=par[u][i], v=par[v][i];
	return par[u][0];
}
 
void dfs2(int now, int bef)
{
	for(int nxt : adj[now])
	{
		if(nxt==bef) continue;
		dfs2(nxt, now);
		dp[now]+=dp[nxt];
	}
}
 
void dfs3(int now, int bef)
{
	if(dp[now]==0)
	{
		P[now]=now;
		if(now!=bef) deg[now]++, deg[P[bef]]++;
	}
	else P[now]=P[bef];
	for(int nxt : adj[now])
	{
		if(nxt==bef) continue;
		dfs3(nxt, now);
	}
}
 
int main()
{
	scanf("%d%d", &N, &K);
	for(int i=1; i<N; i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	for(int i=1; i<=N; i++) scanf("%d", &A[i]), V[A[i]].push_back(i);
 
	dfs(1, 1, 1);
	for(int i=1; i<=20; i++) for(int j=1; j<=N; j++) par[j][i]=par[par[j][i-1]][i-1];
 
	for(int i=1; i<=N; i++) dp[i]++;
	for(int i=1; i<=K; i++)
	{
		int w=V[i][0];
		for(auto it : V[i]) w=lca(w, it);
		dp[w]-=V[i].size();
	}
 
	dfs2(1, 1);
	dfs3(1, 1);
 
	int ans=0;
	for(int i=1; i<=N; i++) if(deg[i]==1) ans++;
	printf("%d\n", (ans+1)/2);
}