1195 - NMult

De la Universitas MediaWiki
Versiunea pentru tipărire nu mai este suportată și poate avea erori de randare. Vă rugăm să vă actualizați bookmarkurile browserului și să folosiți funcția implicită de tipărire a browserului.

Enunț

Se consideră trei numere naturale nenule n, k și w.

Cerința

Să se scrie un program care determină numărul m al mulțimilor de forma {x[1], x[2], … ,x[k]} având ca elemente numere naturale nenule, ce satisfac simultan condițiile:

1 ≤ x[1] < x[2] < ... < x[k] ≤ n x[i+1] - x[i] ≥ w, 1 ≤ i ≤ k - 1

Date de intrare

Fișierul de intrare nmult.in conține pe prima linie trei numere naturale nenule n, k, w separate prin câte un spaţiu, cu semnificaţia de mai sus.

Date de ieșire

Fișierul de ieșire nmult.out va conține pe prima linie restul împărţirii numărului m la 666013.

Restricții și precizări

1 ≤ n, k, w ≤ 1.000.000;

Exemplu

Exemplu1

nmult.in

5 2 2

nmult.out

6

Exemplu2

nmult.in

10 3 4

nmult.out

4

Exemplu3

nmult.in

10 4 4

nmult.out

0

Rezolvare

 
def validare(n, k, w):
    if not(1 <= n <= 1000000 and 1 <= k <= 1000000 and 1 <= w <= 1000000):
        return False
    return True

def rezolvare(n, k, w):
    MOD = 666013
    # Cazul de bază
    if k == 1:
        return n

    # Cazul în care diferența minimă între elemente este mai mare decât n
    if w >= n:
        return 0

    # Aplicăm formula de recursivitate
    dp = [[0 for _ in range(k + 1)] for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][1] = 1
    for j in range(2, k + 1):
        for i in range(1, n + 1):
            dp[i][j] = (dp[i - 1][j] + dp[max(0, i - w)][j - 1]) % MOD

    # Suma soluțiilor pentru mulțimile de k elemente
    sol = 0
    for i in range(1, n + 1):
        sol = (sol + dp[i][k]) % MOD

    return sol

def main():
    with open("nmult.in", "r") as fin:
        n, k, w = map(int, fin.readline().split())

    if not validare(n, k, w):
        print("Date de intrare invalide")
        return

    m = rezolvare(n, k, w)

    with open("nmult.out", "w") as fout:
        fout.write(str(m) + "\n")

if __name__ == "__main__":
    main()

Explicații

Codul are trei funcții principale:

Funcția validare primește parametrii n, k și w și verifică dacă aceștia sunt în intervalul valid (între 1 și 1000000). Dacă parametrii nu sunt în intervalul valid, se returnează False, altfel se returnează True.

Funcția rezolvare primește parametrii n, k și w. În primul rând, se verifică cazurile speciale. Dacă k == 1, atunci numărul de mulțimi posibile este n. Dacă w >= n, atunci nu există nicio mulțime de k elemente cu diferența între oricare 2 termeni consecutivi mai mare sau egală cu w, așadar numărul de mulțimi posibile este 0.

În cazul general, se aplică formula de recursivitate pentru a calcula numărul de mulțimi posibile. Mai precis, se calculează matricea dp, unde dp[i][j] reprezintă numărul de mulțimi de j elemente care se termină cu numărul i și ale căror diferențe între oricare 2 termeni consecutivi sunt cel puțin w. Acest lucru se realizează cu ajutorul formulei: dp[i][j] = dp[i-1][j] + dp[max(0, i-w)][j-1].

În cele din urmă, se calculează suma soluțiilor pentru toate mulțimile de k elemente, care încep cu elementul 1 până la elementul n-w+1.

Funcția main citește datele de intrare din fișierul nmult.in și le validează cu ajutorul funcției validare. Apoi, calculează soluția problemei cu ajutorul funcției rezolvare și o scrie în fișierul nmult.out.