BOJ 3783

 

https://www.acmicpc.net/problem/3783

알고리즘 문제 해결 전략 #13 에 나오는 bisection method 를 적용하기 위한 간단한 연습문제로 생각을 했으나,

소수 정밀도 때문에 bigint 자료형을 직접 만들어서 binary search 를 적용해서 풀었다.

bisection method 를 통해 정밀한 세제곱근을 구해서 푸는 방법은 소수 정밀도의 한계 때문에 힘들어 보인다.

SPOJ 의 같은 문제를 푼 아래 링크들이나, BOJ 에서 해결한 사람들 모두 이분탐색을 많이 사용했다.

참고 링크 1 참고 링크 2

Source Code


//BOJ3783

#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;

// File-scope radix constant: 100000, i.e. five decimal digits per chunk.
// NOTE(review): the arithmetic visible in this file reads bigInt::base, not this global.
const int base = 100000;
// Number of test cases; filled from stdin in proc().
int T;

// Shorthand used by bigInt and by the arithmetic helpers that follow it.
#define ll long long

/**
 * Fixed-capacity non-negative big integer stored little-endian in base-100000
 * chunks (five decimal digits per chunk).
 *
 *   digits[i]  - the i-th chunk, least significant first
 *   lastdigit  - index of the most significant chunk in use (0 for the value 0)
 *   base       - radix of one chunk; shared by every instance
 *
 * Copying and assignment use the compiler-generated operations. The struct
 * holds only plain arrays and ints, and `base` is static, so nothing blocks
 * the implicit assignment operator. Assignment therefore takes its argument by
 * const reference and returns *this by reference instead of copying the
 * ~4 KB payload twice.
 */
struct bigInt {
	// Radix of a single chunk.
	static constexpr ll base = 100000;

	int lastdigit;
	// 5 digit chunk
	ll digits[505];

	// Constructs the value zero with every chunk cleared.
	bigInt() : lastdigit(0) {
		memset(digits, 0, sizeof(digits));
	}
};

// Returns 10 raised to the n-th power; any n <= 0 yields 1.
int pow10(int n) {
	int result = 1;
	while (n-- > 0) {
		result *= 10;
	}
	return result;
}

// Lowers n.lastdigit past any zero chunks at the top, stopping at index 0 so
// that the value zero keeps exactly one chunk.
void zero_justify(bigInt& n) {
	int top = n.lastdigit;
	while (top > 0 && n.digits[top] == 0) {
		--top;
	}
	n.lastdigit = top;
}


void init(const string& s, bigInt& n) {
	// get input. 5 digit chunks
	memset(n.digits, 0, sizeof(n.digits));
	int ssize = s.size();
	int idx, chunk;
	idx = 0;
	for (int i = 0; i < ssize; i += 5) {
		chunk = 0;
		for (int j = 0; j < 5; ++j) {
			int tidx = ssize - 1 - i - j;
			if (tidx >= 0) {
				chunk += (s[tidx] - '0') * pow10(j);
			}
		}
		n.digits[idx] = chunk;
		++idx;
	}

	n.lastdigit = idx;

	zero_justify(n);
}

// Three-way comparison with an inverted sign convention:
//   returns  1 when lhs <  rhs,
//   returns -1 when lhs >  rhs,
//   returns  0 when lhs == rhs.
int compare(bigInt& lhs, bigInt& rhs) {
	// Different chunk counts decide the order immediately.
	if (lhs.lastdigit != rhs.lastdigit) {
		return lhs.lastdigit < rhs.lastdigit ? 1 : -1;
	}

	// Same length: the most significant differing chunk decides.
	for (int i = lhs.lastdigit; i >= 0; --i) {
		const long long l = lhs.digits[i];
		const long long r = rhs.digits[i];
		if (l != r) {
			return l > r ? -1 : 1;
		}
	}

	return 0;
}

// Shifts n up by d chunks, i.e. multiplies it by base^d (base = 100000).
// The value zero is left untouched.
void digitshift(bigInt& n, int d) {
	const bool is_zero = (n.lastdigit == 0 && n.digits[0] == 0);
	if (!is_zero) {
		// Move from the top down so no chunk is overwritten before it is read.
		int src = n.lastdigit;
		while (src >= 0) {
			n.digits[src + d] = n.digits[src];
			--src;
		}

		// Clear the d low chunks that were vacated.
		for (int k = d - 1; k >= 0; --k) {
			n.digits[k] = 0;
		}

		n.lastdigit += d;
		zero_justify(n);
	}
}

// Chunk-wise addition of two big integers with carry propagation.
bigInt operator + (bigInt lhs, bigInt rhs) {
	bigInt sum;

	// Reserve one chunk above the longer operand for a final carry.
	sum.lastdigit = max(lhs.lastdigit, rhs.lastdigit) + 1;

	ll carry = 0;
	for (int i = 0; i <= sum.lastdigit; ++i) {
		const ll column = lhs.digits[i] + rhs.digits[i] + carry;
		carry = column / lhs.base;
		sum.digits[i] = column % lhs.base;
	}

	zero_justify(sum);
	return sum;
}

// Adds the machine integer rhs to lhs, rippling the carry upward.
bigInt operator + (bigInt lhs, ll rhs) {
	bigInt sum = lhs;

	// One spare chunk on top in case the carry runs past the old top.
	sum.lastdigit += 1;

	// Seeding the carry with rhs is the same as adding rhs to chunk 0 first.
	ll carry = rhs;
	for (int i = 0; i <= sum.lastdigit; ++i) {
		const ll column = sum.digits[i] + carry;
		carry = column / lhs.base;
		sum.digits[i] = column % lhs.base;
	}

	zero_justify(sum);
	return sum;
}

// Subtracts the machine integer rhs from lhs.
//
// rhs is consumed one base-100000 chunk at a time, and a borrow is carried
// upward whenever a chunk would go negative. This keeps every chunk inside
// [0, base) for any rhs, not only for rhs no larger than one chunk. The old
// code added base a single time, so e.g. 300000 - 250000 left a negative chunk.
//
// Preconditions (not checked): rhs >= 0 and rhs <= lhs. The type has no sign,
// so a negative result cannot be represented.
bigInt operator - (bigInt lhs, ll rhs) {
	bigInt ret = lhs;
	const ll radix = lhs.base;

	// Amount still to subtract, expressed in units of the current chunk.
	ll pending = rhs;

	for (int i = 0; i <= ret.lastdigit && pending != 0; ++i) {
		ll tmp = ret.digits[i] - pending % radix;
		pending /= radix;
		if (tmp < 0) {
			// Borrow one unit from the next chunk up.
			tmp += radix;
			++pending;
		}
		ret.digits[i] = tmp;
	}

	zero_justify(ret);

	return ret;
}

// Schoolbook multiplication: accumulate every chunk product un-normalised,
// then resolve all carries in a single pass.
bigInt operator * (bigInt lhs, bigInt rhs) {
	bigInt product;

	const int ltop = lhs.lastdigit;
	const int rtop = rhs.lastdigit;

	for (int i = 0; i <= ltop; ++i) {
		const ll a = lhs.digits[i];
		for (int j = 0; j <= rtop; ++j) {
			product.digits[i + j] += a * rhs.digits[j];
		}
	}

	// The product needs at most one chunk more than the two tops combined.
	product.lastdigit = ltop + rtop + 1;

	ll carry = 0;
	for (int k = 0; k <= product.lastdigit; ++k) {
		const ll column = product.digits[k] + carry;
		product.digits[k] = column % lhs.base;
		carry = column / lhs.base;
	}

	zero_justify(product);
	return product;
}

// Long division of lhs by the machine integer rhs; the remainder is dropped.
// A zero divisor returns lhs unchanged.
bigInt operator / (bigInt lhs, long long rhs) {
	if (rhs != 0) {
		long long rem = 0;
		// lhs is a by-value copy, so the quotient can be built in place:
		// each chunk is read before it is overwritten.
		for (int i = lhs.lastdigit; i >= 0; --i) {
			const long long cur = lhs.base * rem + lhs.digits[i];
			lhs.digits[i] = cur / rhs;
			rem = cur % rhs;
		}
		zero_justify(lhs);
	}
	return lhs;
}

// True when lhs is strictly greater than rhs (compare() reports that as -1).
bool operator > (bigInt lhs, bigInt rhs) {
	return compare(lhs, rhs) == -1;
}

// True when lhs is strictly less than rhs (compare() reports that as 1).
bool operator < (bigInt lhs, bigInt rhs) {
	return compare(lhs, rhs) == 1;
}

// lhs <= rhs holds exactly when rhs is not strictly less than lhs.
bool operator <= (bigInt lhs, bigInt rhs) {
	if (rhs < lhs) {
		return false;
	}
	return true;
}

// lhs >= rhs holds exactly when rhs is not strictly greater than lhs.
bool operator >= (bigInt lhs, bigInt rhs) {
	if (rhs > lhs) {
		return false;
	}
	return true;
}

// True when both operands hold the same value (compare() reports 0).
bool operator == (bigInt lhs, bigInt rhs) {
	return compare(lhs, rhs) == 0;
}

// Prints a as a fixed-point number whose fractional part lives in chunks
// 0..19: the integer part is chunk 20 and above, and only the two highest
// fractional chunks (19 and 18) are shown, giving ten decimals. No newline
// is emitted.
void print10(bigInt a) {
	const int top = a.lastdigit;

	if (top >= 20) {
		// The leading chunk is printed without zero padding, the rest with it.
		printf("%d", (int)a.digits[top]);
		for (int i = top - 1; i >= 20; --i) {
			printf("%05d", (int)a.digits[i]);
		}
		printf(".");
	}
	else {
		// Nothing at or above chunk 20: the integer part is zero.
		printf("0.");
	}

	printf("%05d", (int)a.digits[19]);
	printf("%05d", (int)a.digits[18]);
}

// Prints n in decimal: the top chunk unpadded, every lower chunk zero-padded
// to five digits. No newline is emitted.
void print(bigInt n) {
	int i = n.lastdigit;
	printf("%d", (int)n.digits[i]);
	while (--i >= 0) {
		printf("%05d", (int)n.digits[i]);
	}
}

void proc() {
	//ifstream cin;
	//cin.open("input.txt");

	cin >> T;
	while (T--) {
		string s;
		cin >> s;
		bigInt lo, hi, target;

		init("0", lo);
		init(s, hi);

		digitshift(hi,60);
		//multiply hi large enough to guarantee the precision
		//shift 60 digits -> decimal separator starts from (60/3)

		target = hi;
		bigInt mid;
		//binary search
		while (lo <= hi) {
			mid = (lo + hi) / 2;
			
			if (mid * (mid * mid) > target)
				hi = mid - 1;
			else
				lo = mid + 1;
		}

		lo = lo - 1;

		print10(lo);

		printf("\n");
	}
}

void test() {
	bigInt a, b;
	init("100000", b);
	init("99999", a);

	bigInt c = a*b;
	print(c);
	return;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	//test();
	proc();
	return 0;
}


아래는 내가 참고한 자료(위 참고링크 1). 해결 시간이 무려 4ms. 엄청나게 빠르다.


//Alfonso2 Peterssen
//SPOJ #291 "Cube Root"
//16 - 9 - 2008

#include <cstdio>
#include <algorithm>
#include <cstring>

using namespace std;

// REP(i, n): loop header that counts i upward over 0, 1, ..., n-1.
#define REP( i, n ) \
	for ( int i = 0; i < (n); i++ )
// REPD(i, n): loop header that counts i downward over n-1, ..., 1, 0.
#define REPD( i, n ) \
	for ( int i = (n) - 1; i >= 0; i-- )

// 64-bit integer used for limbs and intermediate arithmetic.
typedef long long int64;

const int
// Radix of one limb: 10^7.
BASE = 10000000,
// Decimal digits per limb; used as the read chunk width and the print padding.
BASE_LOG = 7,
// Number of limbs in a bigint's storage array.
MAXLEN = 100;

// Number of test cases, read in main().
int T;
// Fixed-capacity big integer: `size` limbs in use, stored least significant
// first in T, each limb holding one base-BASE digit.
struct bigint {
	int size;
	int64 T[MAXLEN];

	// Builds a number whose lowest limb is x and whose other limbs are zero;
	// `size` sets how many limbs count as in use.
	bigint(int64 x = 0, int size = 1) {
		this->size = size;
		fill(T, T + MAXLEN, 0);
		T[0] = x;
	}
};

// Strict less-than: fewer limbs means smaller; with equal limb counts the
// most significant differing limb decides.
bool operator < (const bigint& a, const bigint& b) {
	if (a.size < b.size) return true;
	if (a.size > b.size) return false;

	for (int i = a.size - 1; i >= 0; --i) {
		if (a.T[i] < b.T[i]) return true;
		if (a.T[i] > b.T[i]) return false;
	}
	return false;
}

// a <= b holds exactly when b is not strictly less than a.
bool operator <= (const bigint& a, const bigint& b) {
	if (b < a) {
		return false;
	}
	return true;
}

// Brings bn back to canonical form in place and returns it:
//   - a negative limb borrows from the limb above until it is non-negative,
//   - values of BASE or more are carried upward,
//   - a leftover carry appends new limbs,
//   - zero limbs at the top are trimmed, always keeping at least one limb.
bigint& normal(bigint& bn) {
	int64 carry = 0;

	for (int i = 0; i < bn.size; ++i) {
		// Repay a negative limb one BASE at a time from its upper neighbour.
		for (; bn.T[i] < 0; bn.T[i] += BASE) {
			--bn.T[i + 1];
		}

		const int64 total = bn.T[i] + carry;
		carry = total / BASE;
		bn.T[i] = total % BASE;
	}

	while (carry > 0) {
		bn.T[bn.size++] = carry % BASE;
		carry /= BASE;
	}

	while (bn.size > 1 && bn.T[bn.size - 1] == 0) {
		--bn.size;
	}

	return bn;
}

// Limb-wise sum of a and b; carries are resolved by normal().
bigint operator + (bigint a, bigint& b) {
	if (a.size < b.size) {
		a.size = b.size;
	}
	for (int i = 0; i < a.size; ++i) {
		a.T[i] += b.T[i];
	}
	normal(a);
	return a;
}

// Adds the machine integer x to the lowest limb, then renormalises.
bigint operator + (bigint a, int64 x) {
	a.T[0] = a.T[0] + x;
	normal(a);
	return a;
}

// Limb-wise difference a - b; negative limbs are repaid by borrowing inside
// normal().
//
// Fixes a copy-paste bug: the body used `+=`, so this operator computed a + b.
//
// NOTE(review): assumes a >= b. A negative result is not representable, and
// normal() would borrow from the limb above the top one.
bigint operator - (bigint a, bigint& b) {
	a.size = max(a.size, b.size);
	REP(i, a.size)
		a.T[i] -= b.T[i];
	return normal(a);
}

// Subtracts the machine integer x from the lowest limb; normal() then borrows
// from higher limbs if that limb went negative.
bigint operator - (bigint a, int64 x) {
	a.T[0] = a.T[0] - x;
	normal(a);
	return a;
}

// Schoolbook product: accumulate all limb products un-normalised, then let
// normal() resolve the carries and trim the result.
bigint operator * (bigint a, bigint b) {
	bigint product(0, a.size + b.size);
	for (int i = 0; i < a.size; ++i) {
		const int64 ai = a.T[i];
		for (int j = 0; j < b.size; ++j) {
			product.T[i + j] += ai * b.T[j];
		}
	}
	normal(product);
	return product;
}

// Scales every limb by the machine integer x, then renormalises.
bigint operator * (bigint a, int64 x) {
	for (int i = 0; i < a.size; ++i) {
		a.T[i] = a.T[i] * x;
	}
	normal(a);
	return a;
}

// Long division of bn by the machine integer x.
// Returns the pair (quotient, remainder).
pair< bigint, int64 > divmod(bigint bn, int64 x) {
	int64 rem = 0;
	for (int i = bn.size - 1; i >= 0; --i) {
		const int64 cur = BASE * rem + bn.T[i];
		rem = cur % x;
		bn.T[i] = cur / x;
	}
	normal(bn);
	return pair< bigint, int64 >(bn, rem);
}

// Quotient of bn divided by x (the first half of divmod).
bigint operator / (bigint bn, int64 x) {
	pair< bigint, int64 > qr = divmod(bn, x);
	return qr.first;
}

// Remainder of bn divided by x (the second half of divmod).
int64 operator % (bigint bn, int64 x) {
	pair< bigint, int64 > qr = divmod(bn, x);
	return qr.second;
}

// Prints bn in decimal: the top limb unpadded, each lower limb zero-padded
// to BASE_LOG digits. No newline is emitted.
void print(bigint bn) {
	int i = bn.size - 1;
	printf("%lld", bn.T[i]);
	for (--i; i >= 0; --i) {
		printf("%0*lld", BASE_LOG, bn.T[i]);
	}
}

// Reads one whitespace-delimited decimal token from stdin and appends it to
// bn as base-BASE limbs.
//
// bn is expected to be freshly default-constructed (size 1, value 0). The
// limbs are appended after the existing zero limb in most-significant-first
// order, the array is reversed so the least significant limb lands at index
// 0, and normal() trims the leftover zero that is now on top.
//
// If no token can be read, bn is left untouched.
void read(bigint& bn) {
	static char line[200];
	// "%199s" bounds the token to the buffer. The array itself is passed
	// (not its address) so the argument type matches what %s expects.
	if (scanf("%199s", line) != 1)
		return;
	int len = strlen(line);
	// Start at a negative offset when len is not a multiple of BASE_LOG, so
	// that the first limb is implicitly padded with leading zeros.
	int i = len % BASE_LOG;
	if (i > 0) i -= BASE_LOG;
	for (; i < len; i += BASE_LOG) {
		int64 x = 0;
		REP(j, BASE_LOG)
			x = 10 * x + (i + j >= 0 ? line[i + j] - '0' : 0);
		bn.T[bn.size++] = x;
	}
	reverse(bn.T, bn.T + bn.size);
	normal(bn);
}

int main() {

	for (scanf("%d", &T); T--; ) {
		bigint X;
		read(X);

		REP(i, 15) X = X * 100;

		bigint lo(0);
		bigint hi = X;

		while (lo <= hi) {
			bigint mid = (lo + hi) / 2;
			if (mid * mid * mid <= X)
				lo = mid + 1;
			else
				hi = mid - 1;
		}

		lo = lo - 1;

		pair< bigint, int64 > R;
		R = divmod(lo, 100000);
		int p1 = R.second; lo = R.first;
		R = divmod(lo, 100000);
		int p2 = R.second; lo = R.first;

	    print(lo); printf(".%05d%05d\n", p2, p1);
	}

	return 0;
}

wbjeon2k

Pursuit of Progress.

This work is licensed under an Attribution-NonCommercial 4.0 International license. Attribution-NonCommercial 4.0 International