1901 - Median Heaps

De la Universitas MediaWiki
Versiunea din 18 mai 2024 13:50, autor: Oros Ioana Diana (discuție | contribuții)
(dif) ← Versiunea anterioară | Versiunea curentă (dif) | Versiunea următoare → (dif)

Cerința

Se dă un vector de N numere naturale nenule, indexat de la 1.

Se cere să se raspundă la Q interogări de tipul:

  • pentru un interval [l, r] din vector, aflați costul total mimin, al egalizării tuturor elementelor din interval.

Într-un interval [l, r], puteți crește sau micșora fiecare element cu costul x unde x este diferența dintre valoarea nouă și valoarea inițială. Costul total este suma acestor costuri.

Date de intrare

  • pe prima linie numărul N.
  • pe a doua linie N numere naturale nenule : , , … .
  • pe a treia linie numărul Q.
  • pe următoarele Q linii 2 numere: l, r.

Date de ieșire

Se vor afișa Q numere pe fiecare linie, reprezentând constul total minim, al fiecărui interval l r.

Restricții și precizări

  • 1 ≤ N, Q ≤ 100.000
  • 1 ≤ l ≤ r ≤ 100.000
  • 1 ≤  ≤ 1.000.000.000

Exemplu:

Intrare

9
3 2 16 15 12 6 20 4 5
3
2 7
2 2
3 8

Ieșire

31
0
29

Explicație

Pentru intervalul [2, 7], costul total minim este 31, deoarece egalizăm fiecare număr din interval cu 12.

Pentru intervalul [2, 2], costul total minim este 0, deoarece avem un singur element în interval.

Pentru intervalul [3, 8], costul total minim este 29, deoarece egalizăm fiecare număr din interval cu 12

Rezolvare

import math

Nmax = 100001

class Query:
    def __init__(self, st, dr, pos):
        self.st = st
        self.dr = dr
        self.pos = pos

def update(aib, aibs, pos, val1, val2):
    while pos < Nmax:
        aib[pos] += val1
        aibs[pos] += val2
        pos += pos & -pos

def query(aib, aibs, pos, c):
    sum_val = 0
    s = 0
    while pos > 0:
        sum_val += aib[pos]
        if aibs is not None:
            s += aibs[pos]
        pos -= pos & -pos
    return sum_val if c == 1 else s

def add(k, N, V, S, aib, aibs):
    update(aib, aibs, N[k], 1, V[S[N[k]]])

def delete(k, N, V, S, aib, aibs):
    update(aib, aibs, N[k], -1, -V[S[N[k]]])

def binary_search(pos, n, aib):
    st, dr = 1, n
    ans = -1
    while st <= dr:
        mid = (st + dr) // 2
        q = query(aib, None, mid, 1)
        if q >= pos:
            dr = mid - 1
            ans = mid
        else:
            st = mid + 1
    return ans

def verify_restrictions(n, q, ranges):
    if not (1 <= n <= 100000 and 1 <= q <= 100000):
        return False
    for l, r in ranges:
        if not (1 <= l <= r <= 100000):
            return False
    return True

def main():
    n = int(input())
    V = [0] * (n + 1)
    
    values = list(map(int, input().split()))
    for i in range(1, n + 1):
        V[i] = values[i-1]
    
    q = int(input())
    ranges = []
    Q = []
    for i in range(1, q + 1):
        st, dr = map(int, input().split())
        ranges.append((st, dr))
        Q.append(Query(st, dr, i))

    if not verify_restrictions(n, q, ranges):
        print("Datele nu corespund restrictiilor impuse")
        return

    block = int(math.sqrt(n))
    S = list(range(n + 1))
    S[1:] = sorted(S[1:], key=lambda x: V[x])
    N = [0] * (n + 1)
    for i in range(1, n + 1):
        N[S[i]] = i

    Q.sort(key=lambda x: (x.dr // block, x.st if x.dr // block % 2 == 0 else -x.st))
    aib = [0] * Nmax
    aibs = [0] * Nmax
    Rez = [0] * (q + 1)
    
    st, dr = 1, 0
    for i in range(1, q + 1):
        s, d = Q[i-1].st, Q[i-1].dr
        while st < s:
            delete(st, N, V, S, aib, aibs)
            st += 1
        while st > s:
            st -= 1
            add(st, N, V, S, aib, aibs)
        while dr < d:
            dr += 1
            add(dr, N, V, S, aib, aibs)
        while dr > d:
            delete(dr, N, V, S, aib, aibs)
            dr -= 1

        poss = (d - s + 2) // 2
        ans = binary_search(poss, n, aib)
        s1 = query(aib, aibs, n, 2)
        h1 = query(aib, None, n, 1)
        s2 = query(aib, aibs, ans, 2)
        h2 = query(aib, None, ans, 1)
        s1 -= s2
        h1 -= h2

        Rez[Q[i-1].pos] = s1 - h1 * V[S[ans]] + (-s2 + h2 * V[S[ans]])

    for i in range(1, q + 1):
        print(Rez[i])

if __name__ == "__main__":
    main()