Divide & Conquer DP Optimization 알고리즘
Divide & Conquer DP Optimization 알고리즘
지난 Knuth Optimization에 이어 이번에는 Divide & Conquer Optimization에 대해 소개하도록 하겠다.
1. 예시 문제: 김치(BOJ 11001)
https://www.acmicpc.net/problem/11001
이 문제는 $1 \leq q \leq p \leq N, p \leq q+D$일 때 $max(T[p]*(q-p)+V[q])$를 구하는 문제이다. 이를 단순히 반복문만을 이용하여 해결하게 되면 $O(N^2)$으로 TLE가 나게 된다.
2. Divide & Conquer Optimization의 조건과 최적화
Divide & Conquer Optimization은 점화식이 다음과 같은 조건을 만족할 때 사용할 수 있다.
조건 1) $DP[t][i]=min_{k<i}(DP[t-1][k]+C[k][i])$
조건 2) $A[t][i]=C[t][i]$를 만족시키는 최소 $k$라고 할 때 $A[t][i] \leq A[t][i+1]$가 만족
조건 2’) $a \leq b \leq c \leq d$ 일때 $C[a][c]+C[b][d] \leq C[a][d]+C[b][c]$ 를 만족한다.
조건 1과 조건 2나 조건 2’을 만족하면, $DP[t][i]$를 계산했을 때 $k=p$일때 최소값이 나왔다면, $DP[t][x<i]$를 구할 때에는 $k \leq p$에 대해서만 살펴보면 되고, $DP[t][x>i]$를 구할 때에는 $k \le p$에 대해서만 살펴보면 된다.
따라서 이를 이용하면 $O(KN^{2})$의 DP를 $O(KNlgN)$까지 줄일 수 있다. 이 때의 코드는 다음과 같다.
#include<bits/stdc++.h>
#define MAXN 5000
#define MAXK 1000
#define oo 1987654321
typedef long long ll;
using namespace std;
int DP[MAXK][MAXN],C[MAXN][MAXN];
void dc_opt(int t,int s,int e,int l,int r) // DP[t][s~e]를 k의 범위 p~q 내에서 계산한다.
{
if(s>e) return;
int mid=(s+e)/2;
DP[t][mid]=oo;
int v,w;
for(int k=l;k<=min(r,m-1);k++)
{
v=DP[t-1][k]+C[k][mid];
if(DP[t][mid]>v) DP[t][mid]=v,w=k;
}
dc_opt(t,s,mid-1,l,w);
dc_opt(t,mid+1,e,w,r);
}
3. Divide & Conquer Optimization의 사용
우선 김치 문제를 $O(N^{2})$로 해결할 때 사용했던 식을 다시 한 번 살펴보자. \(max(T[p]\*(q-p)+V[q]\)
일단 $C[q][p]=T[p]*(q-p)+V[q]$라고 하면, $a \leq b \leq c \leq d$ 일때 $C[a][c]+C[b][d]-C[a][d]+C[b][c] \le 0$을 언제나 만족함을 확인할 수 있다.
따라서 $T’[i]=-T[i]$, $V’[i]=-V[i]$로 재정의하면, 구하는 것은 $-min(T’[p]*(q-p)+V’[q]$ 가 되고, $C’[q][p]$는 $C’[a][c]+C’[b][d]-C’[a][d]+C’[b][c] \leq 0$을 만족한다.
$DP[p]=-min(T’[p]*(q-p)+V’[q]$라고 두면 이를 이용하여 문제를 해결할 수 있다.
단 이 때 $D$ 조건을 고려하여 가능한 $q$의 범위를 적당히 잘 설정해주어야 한다.
코드는 다음과 같다.
#include<bits/stdc++.h>
#define oo 1987654321
typedef long long ll;
using namespace std;
int N,D,DP[111111];
long long T[111111],V[111111],ans;
void dc_opt(int s, int e, int l, int r)//DP[s]~DP[e]를 구하는 함수, 이 때 k의 범위
{
if(s>e) return ;
int mid=(s+e)/2;
ll maxi=0;
int tp=0;
for(int i=max(l,mid-D);i<=min(mid,r);i++)
{
if(T[mid]*(mid-i)+V[i]>maxi) maxi=T[mid]\*(mid-i)+V[i],tp=i;
}
if(ans<maxi) ans=maxi;
dp_opt(s,mid-1,l,tp);
dp_opt(mid+1,e,tp,r);
}
int main(void)
{
scanf("%d %d",&N,&D);
for(int i=1;i<=N;i++) scanf("%d",&T[i]);
for(int i=1;i<=N;i++) scanf("%d",&V[i]);
dc_opt(1,N,1,N);
printf("%lld",ans);
}