点分树


先开个坑。

考虑一个树分治的经典问题:给出一棵带边权树、黑点集合、白点集合,然后求出每个黑点到每个白点的距离和。

那么只需枚举分治中心得到的所有黑点,然后乘上白点的距离总和,然后减去所有儿子做同样过程加上儿子到自己的距离。

总之就是总的贡献、减去子树内部的贡献、再减去子树内部到自己的一些额外贡献。


然后考虑支持修改 的颜色,我们发现计算了 的贡献的分治中心只会有 个,于是我们暴力枚举这个分治中心,记录 黑点/白点 到当前点的点数以及距离总和,和 黑点/白点 到分治父亲的距离总和即可计算。

注意点分树上父子的方向关系没有传递性,所以所有的距离都必须直接求出而不能累加。

CODE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
void add(int &x,int y){if((x+=y)>=mo) x-=mo;}
void dec(int &x,int y){if((x-=y)<0) x+=mo;}
struct tree{
int dfsn,rk[N+N],dp[N],len[N];
int sz[N],ans,tfa[N],cnt[N][2],up[N][2],down[N][2];
bool vis[N],col[N];
pair<int,int> st[20][N+N];
vector<pair<int,int>> t[N];
void push(int x,int y,int w){
t[x].emplace_back(y,w);
t[y].emplace_back(x,w);
}
int lca(int x,int y){
x=rk[x];y=rk[y];
if(x>y) swap(x,y);
int s=31-__builtin_clz(y-x+1);
return(min(st[s][x],st[s][y-(1<<s)+1]).second);
}
int dis(int x,int y){
int l=lca(x,y),re=len[x]+len[y];
if((re-=len[l])<0) re+=mo;
if((re-=len[l])<0) re+=mo;
if(re>=mo) re-=mo;
return(re);
}
void dfs0(int x,int fa=0){
st[0][rk[x]=++dfsn]={dp[x]=dp[fa]+1,x};
sz[x]=1;
for(auto i:t[x])
if(!vis[i.first]&&i.first!=fa){
if((len[i.first]=len[x]+i.second)>=mo) len[i.first]-=mo;
dfs0(i.first,x);
sz[x]+=sz[i.first];
st[0][++dfsn]={dp[x],x};
}
}
int dfs(int x,int dp=0,int fa=0){
sz[x]=1;
if(dp>=mo) dp-=mo;
int re=dp;
for(auto i:t[x])
if(!vis[i.first]&&i.first!=fa){
add(re,dfs(i.first,dp+i.second,x));
sz[x]+=sz[i.first];
}
return(re);
}
int root(int x){
int all=sz[x];
for(auto i=t[x].begin();i!=t[x].end();){
int to=i->first;
if(!vis[to]&&sz[to]<<1>=all&&sz[to]<sz[x])
i=t[x=to].begin();
else ++i;
}
return(x);
}
void solve(int x,int fa){
tfa[x]=fa;
for(int i=x;tfa[i];i=tfa[i])
add(up[i][0],dis(x,tfa[i]));
vis[x]=1;
down[x][0]=dfs(x);
cnt[x][0]=sz[x];
for(auto i:t[x])
if(!vis[i.first])
solve(root(i.first),x);
}
void modify(int x){
bool c=col[x];
col[x]^=1;
for(int s=x,pre=0,prel=0;x;x=tfa[x]){
int now=dis(x,s),nxt=dis(s,tfa[x]);
dec(down[x][c],now);
dec(up[x][c],nxt);
--cnt[x][c];
ans=(ans+down[x][c]-down[x][c^1]+now*ll(cnt[x][c]-cnt[x][c^1]-pre)-prel)%mo;
if(ans<0) ans+=mo;
pre=cnt[x][c]-cnt[x][c^1];
prel=up[x][c]-up[x][c^1];
++cnt[x][c^1];
add(down[x][c^1],now);
add(up[x][c^1],nxt);
}
}
void build(){
dfs0(1);
for(int i=1;i<20;++i)
for(int j=1,r=1<<(i-1);j+r+r-1<=n+n;++j)
st[i][j]=min(st[i-1][j],st[i-1][j+r]);
solve(root(1),0);
}
};