点分治
<p>下面的代码用点分治求树上所有路径的“点权最大值 × 点权最小值”之和(路径的两个端点可以相同),结果对 998244353 取模。</p>
<pre><code class="language-cpp">#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
//#include <unordered_map>
using namespace std;
// 64-bit signed integer used for weights and modular sums.
using ll = long long;
// (max rank, min rank) collected along a path; filled in by getdis().
using Node = std::pair<ll, ll>;
// One Fenwick-tree cell: (modular sum of inserted values, number of insertions).
using Tree = std::pair<ll, int>;
// One vertex weight plus the bookkeeping needed for rank compression.
struct Num {
    ll num;     // weight value as read from the input
    int place;  // 1-based vertex index this weight belongs to
    int rnk;    // dense rank of num among all weights (equal weights share a rank)
};
// Capacity of every per-vertex array (vertices are numbered from 1).
const int maxn = 100005;
// Modulus the answer is reduced by.
const ll mod = 998244353;
// Adjacency lists of the tree.
vector<int> mp[maxn];
// Number of vertices.
int n;
// Not referenced anywhere in the visible code.
int k;
// Weight records; once main() has re-sorted them by place, w[i] describes vertex i.
Num w[maxn];
// sa[r] is the weight value whose dense rank is r.
ll sa[maxn];
// siz[u]: subtree size of u as computed by the most recent getroot() pass.
int siz[maxn];
// sonsiz[u]: size of the largest piece that remains when u is removed.
// sonsiz[0] starts at maxn so that root == 0 behaves as "no candidate yet".
int sonsiz[maxn] = {maxn};
// Best centre candidate found so far by getroot(); 0 means none.
int root;
// vis[u] becomes true once u has been used as a decomposition centre.
bool vis[maxn];
// Scratch list filled by getdis(): one (max rank, min rank) entry per reached vertex.
vector<Node> dis;
// Size of the component getroot() is currently searching.
int subn;
// Running answer, kept in [0, mod).
ll ans = 0;
// Returns (a + b) reduced into [0, mod).
// Each operand is reduced before the addition so the intermediate sum cannot
// overflow, and % is used instead of repeated subtraction so the cost is
// constant no matter how large the operands are.
ll qadd(ll a, ll b){
    ll sum = (a % mod + b % mod) % mod;
    // In C++ the remainder takes the sign of the dividend, so sum may be negative.
    if (sum < 0) sum += mod;
    return sum;
}
// Fenwick tree indexed by rank: add()/sub() fold a value into .first with qadd
// and move .second by one, so .first is a modular sum and .second a count.
Tree tree[maxn];
// Isolates the lowest set bit of x; yields 0 when x is 0.
int lowerbit(int x){
    const int negated = -x;
    return x & negated;
}
// Inserts one entry of value x at rank z: every Fenwick cell covering z gets
// x folded into its sum and its count increased by one.
// z must be at least 1: lowerbit(0) is 0, so a walk starting at 0 never advances.
void add(int z, ll x){
    int i = z;
    while (i <= n) {
        tree[i].first = qadd(tree[i].first, x);
        ++tree[i].second;
        i += lowerbit(i);
    }
}
// Reverses a previous add(z, x): every Fenwick cell covering z has x removed
// from its sum and its count decreased by one.
// z must be at least 1: lowerbit(0) is 0, so a walk starting at 0 never advances.
void sub(int z, ll x){
    int i = z;
    while (i <= n) {
        tree[i].first = qadd(tree[i].first, -x);
        --tree[i].second;
        i += lowerbit(i);
    }
}
// Prefix query on the value sums: modular sum of all values inserted at ranks 1..z.
// Returns 0 when z <= 0.
ll sum1(int z){
    ll total = 0;
    int i = z;
    while (i > 0) {
        total = qadd(total, tree[i].first);
        i -= lowerbit(i);
    }
    return total;
}
// Prefix query on the counts: how many entries are currently inserted at ranks 1..z.
// Returns 0 when z <= 0.
int sum2(int z){
    int count = 0;
    int i = z;
    while (i > 0) {
        count += tree[i].second;
        i -= lowerbit(i);
    }
    return count;
}
// Depth-first pass from u (arriving from fa) over unvisited vertices.
// Fills siz[] and sonsiz[] and leaves in `root` the vertex whose largest
// remaining piece is smallest, i.e. the centroid candidate.
// Expects subn to hold the size of the component and root to be 0 on entry,
// so that the sonsiz[0] sentinel loses the first comparison.
void getroot(int u, int fa){
    siz[u] = 1;
    sonsiz[u] = 0;
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        getroot(v, u);
        siz[u] += siz[v];
        if (siz[v] > sonsiz[u]) sonsiz[u] = siz[v];
    }
    // The part of the component that lies "above" u also counts as a piece.
    const int upward = subn - siz[u];
    if (upward > sonsiz[u]) sonsiz[u] = upward;
    if (sonsiz[u] < sonsiz[root]) root = u;
}
// Appends to dis one (max rank, min rank) entry for u and for every unvisited
// vertex reachable through u, where maxdist/mindist are the extremes of the
// ranks seen on the way to u.
void getdis(int u, int fa, int maxdist, int mindist){
    dis.push_back(Node(maxdist, mindist));
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        const int r = w[v].rnk;
        getdis(v, u, max(maxdist, r), min(mindist, r));
    }
}
// Sort predicate: orders weight records by ascending weight value.
bool comp(Num a, Num b){
    return b.num > a.num;
}
// Sort predicate: orders weight records by ascending vertex index,
// which restores the original input order.
bool comp1(Num a, Num b){
    return b.place > a.place;
}
// Sort predicate: orders pairs by ascending second component.
// Not referenced anywhere in the visible code.
bool comp3(Node a, Node b){
    return b.second > a.second;
}
// Collects, via getdis(), one (max rank, min rank) entry per unvisited vertex
// reachable from u, seeded with maxlen/minlen, and returns the sum over all
// unordered pairs of entries (an entry paired with itself included) of
//     sa[larger max rank] * sa[smaller min rank]      (mod `mod`).
//
// Entries are processed in ascending order of max rank, so when entry i is
// paired with any entry j <= i the pair's maximum is entry i's. The Fenwick
// tree, keyed by min rank, then splits those j into
//   - entries whose min rank is >= entry i's: the pair's minimum is entry i's,
//     and only their count is needed;
//   - entries whose min rank is smaller: the pair's minimum is theirs, and the
//     sum of their weights is needed.
// Weights are reduced modulo `mod` before being multiplied so that the
// products stay inside 64 bits whatever the magnitude of the input weights.
ll calc(int u, int maxlen, int minlen){
    ll total = 0;
    dis.clear();
    getdis(u, 0, maxlen, minlen);
    sort(dis.begin(), dis.end());
    for (size_t i = 0; i < dis.size(); i++) {
        const int hi = static_cast<int>(dis[i].first);
        const int lo = static_cast<int>(dis[i].second);
        const ll weightHi = (sa[hi] % mod + mod) % mod;
        const ll weightLo = (sa[lo] % mod + mod) % mod;
        // Insert first so that entry i is also paired with itself.
        add(lo, sa[lo]);
        // Entries seen so far minus those with a strictly smaller min rank.
        const ll notSmaller = qadd(static_cast<ll>(i) + 1, -sum2(lo - 1));
        total = qadd(total, weightHi * weightLo % mod * notSmaller % mod);
        total = qadd(total, weightHi * sum1(lo - 1) % mod);
    }
    // Undo every insertion so the shared Fenwick tree is empty for the next call.
    for (size_t i = 0; i < dis.size(); i++) {
        const int lo = static_cast<int>(dis[i].second);
        sub(lo, sa[lo]);
    }
    return total;
}
// Centroid-decomposition step with u as the centre of the current component.
// Adds to ans the contribution of every pair counted by calc(u, ...), then
// subtracts, per unvisited neighbour v, the pairs that lie entirely on v's side
// (they do not form a path through u), and finally recurses into each side.
// Expects subn to hold the size of u's component and siz[] to come from the
// getroot() pass that selected u.
void divide(int u){
    // Remember the component size now: subn is overwritten for every child below.
    const int total = subn;
    ans = qadd(ans, calc(u, w[u].rnk, w[u].rnk));
    vis[u] = true;
    for (int v : mp[u]) {
        if (vis[v]) continue;
        ans = qadd(ans, -calc(v, max(w[u].rnk, w[v].rnk), min(w[u].rnk, w[v].rnk)));
        // siz[] was computed with the traversal rooted elsewhere. If v was u's
        // parent in that traversal, siz[v] still includes u's subtree, so the
        // real size of v's side is everything except u's subtree.
        subn = (siz[v] > siz[u]) ? total - siz[u] : siz[v];
        root = 0;  // reset to the sentinel so getroot() picks a fresh centre
        getroot(v, 0);
        divide(root);
    }
}
// Resets the per-run state for a tree of n vertices: the answer, the component
// size used by the first getroot() call, the visited flags and the adjacency
// lists of vertices 1..n. n must already be set.
void init(){
    ans = 0;
    subn = n;
    int i = 1;
    while (i <= n) {
        mp[i].clear();
        vis[i] = false;
        ++i;
    }
}
int main(){
scanf("%d", &n);
init();
for(int i = 1; i <= n; i++){
scanf("%lld", &w[i].num);
w[i].place = i;
}
sort(w+1, w+n+1, comp);
for(int i = 1; i <= n; i++){
if(i == 1){
w[i].rnk = 1;
}else{
int r = w[i-1].rnk;
if(w[i-1].num != w[i].num) w[i].rnk = r+1;
else w[i].rnk = r;
}
sa[w[i].rnk] = w[i].num;
}
sort(w+1, w+n+1, comp1);
for(int i = 1; i < n; i++){
int u, v;
scanf("%d%d", &u, &v);
mp[u].push_back(v);
mp[v].push_back(u);
}
getroot(1, 0);
divide(root);
printf("%lld", ans);
return 0;
} </code></pre>