点分治（Centroid Decomposition）

下面的代码求树上所有路径的“路径上最大点权 × 最小点权”之和（结果对 998244353 取模）；路径两端点可以相同，即单个点也算一条路径，每对端点只计一次。

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
//#include <unordered_map>
using namespace std;
// 64-bit integer used for weights and modular sums.
using ll = long long;
// (max rank, min rank) record produced for one node by getdis().
using Node = pair<ll, ll>;
// Fenwick-tree cell: (sum of inserted values mod `mod`, number of inserted entries).
using Tree = pair<ll, int>;

// One input weight plus the bookkeeping main() attaches to it.
struct Num {
    ll num;     // weight as read from input
    int place;  // 1-based input index of the node owning this weight
    int rnk;    // dense rank of num among distinct weights (1 = smallest)
};
// Array capacity; nodes and ranks are 1-based, so n must stay below this.
const int maxn = 100005;
// Modulus applied to every sum.
const ll mod = 998244353;
// unordered_map<int, int> rnk;  (unused alternative to the rank field)

// Adjacency lists of the tree.
vector<int> mp[maxn];
// n: number of nodes. k is not used by the visible code.
int n, k;
// w[i]: weight, input index and dense rank of node i.
Num w[maxn];
// sa[r]: the weight whose dense rank is r.
ll sa[maxn];

// siz[u]: size of u's subtree in the most recent getroot() pass.
int siz[maxn];
// sonsiz[u]: largest piece left if u were removed. sonsiz[0] = maxn is a
// sentinel: with root reset to 0, any real node compares better.
int sonsiz[maxn] = {maxn};
// Best centroid candidate found so far by getroot().
int root;

// vis[u]: u has already been used as a centroid.
bool vis[maxn];
// Scratch buffer filled by getdis(): one (max rank, min rank) record per node.
vector<Node> dis;
// Size of the component getroot() is currently searching.
int subn;
// Running answer, kept in [0, mod).
ll ans = 0;

// Modular addition: returns (a + b) reduced into [0, mod).
// Negative sums are allowed, so qadd(x, -y) doubles as modular subtraction.
// Assumes a + b itself does not overflow ll.
ll qadd(ll a, ll b){
    ll r = (a + b) % mod;
    if (r < 0) r += mod;
    return r;
}

Tree tree[maxn];
int lowerbit(int x){
    return x & -x;
}
void add(int z, ll x){
    for(int i = z; i <= n; i = i + lowerbit(i)){
        tree[i].first = qadd(tree[i].first, x);
        tree[i].second += 1;
    }
}
void sub(int z, ll x){
    for(int i = z; i <= n; i = i + lowerbit(i)){
        tree[i].first = qadd(tree[i].first, -x);
        tree[i].second -= 1;
    }
}
ll sum1(int z){
    ll sum = 0;
    for(int i = z; i > 0; i = i - lowerbit(i)){
        sum = qadd(sum, tree[i].first);
    }
    return sum;
}
int sum2(int z){
    int sum = 0;
    for(int i = z; i > 0; i = i - lowerbit(i)){
        sum += tree[i].second;
    }
    return sum;
}

// Centroid search over the unvisited component containing u.
// For every node x reached it sets siz[x] (subtree size, rooted at the start
// node of this pass) and sonsiz[x] (largest piece left if x were removed,
// using subn as the component size). `root` moves to x whenever sonsiz[x]
// beats sonsiz[root]. fa is u's parent in this pass (0 for the start node).
void getroot(int u, int fa){
    siz[u] = 1;
    sonsiz[u] = 0;
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        getroot(v, u);
        siz[u] += siz[v];
        if (siz[v] > sonsiz[u]) sonsiz[u] = siz[v];
    }
    // The part of the component outside u's subtree is also a piece.
    const int above = subn - siz[u];
    if (above > sonsiz[u]) sonsiz[u] = above;
    if (sonsiz[u] < sonsiz[root]) root = u;
}
// Appends (maxdist, mindist) for u to `dis`, then recurses (pre-order) into
// every neighbour that is neither fa nor already visited. Each child's
// record folds the child's own rank into the running max / min.
void getdis(int u, int fa, int maxdist, int mindist){
    dis.push_back(Node(maxdist, mindist));
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        const int childMax = max(maxdist, w[v].rnk);
        const int childMin = min(mindist, w[v].rnk);
        getdis(v, u, childMax, childMin);
    }
}
// Orders Num records by weight, smallest first (used before assigning ranks).
bool comp(Num a, Num b){ return a.num < b.num; }

// Orders Num records by input index (restores node order after ranking).
bool comp1(Num a, Num b){ return a.place < b.place; }

// Orders Node records by their second component (min rank).
// Not referenced by the visible code.
bool comp3(Node a, Node b){ return a.second < b.second; }

// Collects one (max rank, min rank) record per node reachable from u (via
// getdis, with u's own record = (maxlen, minlen)). Then, for every unordered
// pair of records, including a record paired with itself, it sums
//   sa[max of the two max ranks] * sa[min of the two min ranks]   (mod `mod`).
// The Fenwick tree is left empty on return.
ll calc(int u, int maxlen, int minlen){
    ll res = 0;
    dis.clear();
    getdis(u, 0, maxlen, minlen);
    // Sort by max rank: when record i is processed, every earlier record j
    // has max_j <= max_i, so the pair's maximum is max_i.
    sort(dis.begin(), dis.end());
    for (int i = 0; i < (int)dis.size(); i++) {
        const ll hi = dis[i].first;
        const ll lo = dis[i].second;
        add(lo, sa[lo]);
        // Records j <= i (including i) with min_j >= lo: the pair's min is lo.
        const ll cntNotLess = qadd(i + 1, -sum2(lo - 1));
        res = qadd(res, sa[hi] * sa[lo] % mod * cntNotLess % mod);
        // Records j < i with min_j < lo: the pair's min is min_j, and the
        // tree holds the sum of their sa[min_j].
        res = qadd(res, sa[hi] * sum1(lo - 1) % mod);
    }
    // Undo the insertions so the next call starts from an empty tree.
    for (const Node& d : dis) sub(d.second, sa[d.second]);
    return res;
}
// Centroid-decomposition driver. u must be the root just chosen by a
// getroot() pass over its component, so subn and siz[] still describe that
// pass. Adds to `ans` every pair whose path goes through u, then recurses
// into each component left after removing u.
void divide(int u){
    // Snapshot of the pass that chose u; the recursion below overwrites subn.
    const int compSize = subn;
    const int uSize = siz[u];

    // Every pair in the component, joined through u (u's rank folded in).
    ans = qadd(ans, calc(u, w[u].rnk, w[u].rnk));
    vis[u] = true;
    for (int v : mp[u]) {
        if (vis[v]) continue;
        // Subtract pairs with both endpoints on v's side. They were counted
        // above as if joined through u.
        ans = qadd(ans, -calc(v, max(w[u].rnk, w[v].rnk), min(w[u].rnk, w[v].rnk)));
        // Size of v's component once u is removed. siz[v] is that size only
        // if v was below u in the getroot() pass. If v was u's parent, the
        // piece is everything outside u's subtree.
        subn = (siz[v] < uSize) ? siz[v] : compSize - uSize;
        root = 0;
        getroot(v, 0);
        divide(root);
    }
}
// Resets per-run state for nodes 1..n: adjacency lists, visited flags, the
// answer, and the component size used by the first getroot() pass.
// Call after n has been read.
void init(){
    for (int i = n; i >= 1; --i) {
        mp[i].clear();
        vis[i] = false;
    }
    ans = 0;
    subn = n;
}
int main(){
    scanf("%d", &n);
    init();
    for(int i = 1; i <= n; i++){
        scanf("%lld", &w[i].num);
        w[i].place = i;
    }
    sort(w+1, w+n+1, comp);
    for(int i = 1; i <= n; i++){
        if(i == 1){
            w[i].rnk = 1;
        }else{
            int r = w[i-1].rnk;
            if(w[i-1].num != w[i].num)    w[i].rnk = r+1;
            else    w[i].rnk = r;
        }
        sa[w[i].rnk] = w[i].num;
    }
    sort(w+1, w+n+1, comp1);
    for(int i = 1; i < n; i++){
        int u, v;
        scanf("%d%d", &u, &v);
        mp[u].push_back(v);
        mp[v].push_back(u);
    }
    getroot(1, 0);
    divide(root);
    printf("%lld", ans);
    return 0;

}