JOISC19 Cake3
문제
oj.uz/problem/view/JOI19_cake3
문제 보기 - Cake 3 (JOI19_cake3) :: oj.uz
문제 보기 - Cake 3 (JOI19_cake3)
oj.uz
N개의 케이크 조각이 있고, 이중 M개를 뽑아 적당히 배열하여 $\sum_{i=1}^{M}V_i-\sum_{i=1}^{M}|C_i-C_{i+1}|$값을 최대화해야 한다.
$M<=N<=200000$
풀이
먼저, M개의 조각을 모두 고른 후 배열하는 방법을 생각해 보자.
$\sum_{i=1}^{M}V_i$는 고정이고, $\sum_{i=1}^{M}|C_i-C_{i+1}|$를 최소화해야 한다.
Observation 1 : 최적의 배치는 $C_i$의 오름차순으로 정렬하여 배치하는 것이다.
어떻게 배치하더라도 최소에서 최대로, 최대에서 최소로 가는 경로가 존재하니, 답의 하한은 (최대-최소)*2이고, 실제로 이를 정렬하여 배치하면 그 값을 얻을 수 있다.
즉, 비용 함수는 $\sum_{i=1}^{M}V_i-2(C_{max}-C_{min})$와 같다.
$C_{max}$와 $C_{min}$을 고정하면 C로 정렬된 배열에서 특정 구간이 나온다. 그 구간 중에 M개 원소의 $V_i$의 합을 최대화하기 위해서는 가장 큰 M개를 선택하면 된다.
이제, 문제는 C로 정렬된 배열에서 특정 구간을 잡아 구간 내에 상위 $M$개의 원소의 합을 구하면 비용을 알 수 있게 되었다.
모든 구간에 대하여 다 계산하면 $O(N^2logN)$등에 계산할 수 있다.
봐야 하는 구간의 개수를 줄이기 위해, Divide and Conquer Optimization을 사용하자.
Observation 2 : 구간에 대한 비용 함수는 monge array 이다.
$a<=b<=c<=d$일 때 $f(a, c)+f(b, d)>=f(a, d)+f(b, c)$임을 보이자.
$g(l, r)=$구간에서 상위 M개의 합이라 정의하면, 비용 함수는 $g(l, r)-2(C_r-C_l)$이다.
식을 정리하면 C에 대한 항은 모두 소거되고 $g(a, c)+g(b, d)>=g(a, d)+g(b, c)$, 즉 g가 monge함을 보이면 된다.
귀납법을 이용하자.
M=1일 때 전체 구간의 최댓값 k가 두 구간의 공통부분, 즉 [b, d]에 위치한다면 $g(a, c)+g(b, d)=g(a, d)+g(b, c)$이니 성립한다.
만약 공통부분에 최댓값이 위치하지 않는다면 일반성을 잃지 않고 [a, b)에 위치한다 하자.
$g(a, c)=k, g(a, d)=k$이니, $g(b, d)>=g(b, c)$만 보이면 되는데, $g(b, d)$가 $g(b, c)$를 완전히 포함하니 이는 자명히 성립한다.
M>1일 때도 이미 선택한 원소들을 모두 제거하고 생각하면 M=1일 때와 비슷한 논리로 새로 선택할 원소 또한 조건을 만족하니, 귀납적으로 모든 M에 대해 성립한다.
보아야 하는 구간의 개수가 $O(NlogN)$으로 줄었으니, $f(l, r)$만 빠르게 구할 수 있으면 된다. 구간 내에서 상위 $M$개의 값을 구하는 연산은 Persistent Segment Tree에서 세그먼트 트리 위에서의 이분탐색으로 해결할 수 있으니, $O(logN)$으로 해결할 수 있고, 따라서 전체 문제는 $O(Nlog^2N)$에 풀린다.
시간 복잡도 : $O(Nlog^2N)$
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
const int MAXN = 2e5;
struct Cake
{
int V, C;
bool operator < (const Cake &p) { return C<p.C; }
}A[MAXN+10];
int N, M, S;
ll ans=-1e18;
vector<int> comp;
int getcomp(int x) { return lower_bound(comp.begin(), comp.end(), x)-comp.begin(); }
struct Node
{
ll cnt, sum;
Node *lc, *rc;
Node() : cnt(0), sum(0), lc(NULL), rc(NULL) {}
};
Node *tree[MAXN+10];
void makeTree(Node *node, int tl, int tr)
{
if(tl==tr) return;
int mid=tl+tr>>1;
node->lc=new Node();
node->rc=new Node();
makeTree(node->lc, tl, mid);
makeTree(node->rc, mid+1, tr);
}
Node *addTree(Node *node, int tl, int tr, int pos)
{
if(pos<tl || tr<pos) return node;
Node *ret=new Node();
if(tl==tr)
{
ret->cnt=node->cnt+1;
ret->sum=node->sum+comp[tl];
return ret;
}
int mid=tl+tr>>1;
ret->lc=addTree(node->lc, tl, mid, pos);
ret->rc=addTree(node->rc, mid+1, tr, pos);
ret->cnt=ret->lc->cnt+ret->rc->cnt;
ret->sum=ret->lc->sum+ret->rc->sum;
return ret;
}
ll query(Node *nodel, Node *noder, int tl, int tr, ll k)
{
if(tl==tr) return comp[tl]*min(k, noder->cnt-nodel->cnt);
int mid=tl+tr>>1;
ll t=noder->rc->cnt-nodel->rc->cnt;
if(k>t) return query(nodel->lc, noder->lc, tl, mid, k-t)+noder->rc->sum-nodel->rc->sum;
else return query(nodel->rc, noder->rc, mid+1, tr, k);
}
void solve(int sl, int sr, int el, int er)
{
if(sl>sr) return;
int smid=sl+sr>>1, emid=er; ll val=-1e18;
for(int i=max(smid+M-1, el); i<=er; i++)
{
ll now=query(tree[smid-1], tree[i], 0, S-1, M)-2ll*(A[i].C-A[smid].C);
//printf("%d %d %lld\n", smid, i, now);
if(now>val) val=now, emid=i;
}
ans=max(ans, val);
solve(sl, smid-1, el, emid);
solve(smid+1, sr, emid, er);
}
int main()
{
scanf("%d%d", &N, &M);
for(int i=1; i<=N; i++) scanf("%d%d", &A[i].V, &A[i].C), comp.push_back(A[i].V);
sort(A+1, A+N+1);
sort(comp.begin(), comp.end());
comp.erase(unique(comp.begin(), comp.end()), comp.end());
S=comp.size();
tree[0]=new Node();
makeTree(tree[0], 0, S-1);
for(int i=1; i<=N; i++) tree[i]=addTree(tree[i-1], 0, S-1, getcomp(A[i].V));
solve(1, N, 1, N);
printf("%lld", ans);
}