import csv, sqlite3, operator, math

def main():
	"""Rebuild hospitalData.sqlite from the Medicare FY2011 inpatient CSV files.

	Reads the per-provider charge file to fill the hospitals, discharges and
	prices tables, reads the per-DRG summary file to fill averagePrices, then
	derives the per-hospital price indices and runs the two geocoding steps.
	Finally prints the provider file's column names, the number of price rows
	read and the number of distinct providers.
	"""
	with sqlite3.connect('hospitalData.sqlite') as conn:
		with open('Medicare_Provider_Charge_Inpatient_DRG100_FY2011.csv') as csvfile:
			createTables(conn)
			reader = csv.reader(csvfile, skipinitialspace=True)
			# next(reader) instead of reader.next(): identical on Python 2.6+,
			# and also valid on Python 3.
			fieldnames = [x.strip() for x in next(reader)]
			priceRowCount = 0
			providerIds = set()
			# Maps the numeric DRG code to the row id in the discharges table.
			dischargeIds = {}
			for row in reader:
				if not row:
					# csv.reader yields [] for a blank line; indexing it would
					# raise IndexError.
					continue
				providerId = row[1]
				if providerId not in providerIds:
					addProvider(conn, row)
					providerIds.add(providerId)
				# Column 0 looks like "<code> - <description>".
				drgLabel = row[0]
				dischargeId = int(drgLabel[0:drgLabel.find(' ')], 10)
				if dischargeId not in dischargeIds:
					dischargeDescription = drgLabel[drgLabel.find('-')+2:].strip()
					rowId = addDischarge(conn, dischargeId, dischargeDescription)
					dischargeIds[dischargeId] = rowId
				addPrice(conn, providerId, dischargeIds[dischargeId], int(row[8], 10), row[9], row[10])
				priceRowCount += 1
		with open('Medicare_Charge_Inpatient_DRG100_DRG_Summary_by_DRG_FY2011.csv') as csvfile:
			reader = csv.reader(csvfile, skipinitialspace=True)
			for row in reader:
				# Only rows whose first cell starts with a digit are data rows.
				# Slicing with [:1] keeps blank rows and empty first cells from
				# raising IndexError; they simply fail the digit test.
				if row and '0' <= row[0][:1] <= '9':
					dischargeId = int(row[0][0:row[0].find(' ')], 10)
					addAverage(conn, dischargeIds[dischargeId], float(row[2]), float(row[3]), int(row[1], 10))

		# Averages come from the summary file above rather than being computed:
		#calculateAverages(conn, dischargeIds)
		calculatePriceIndices(conn, providerIds, list(dischargeIds.values()))
		generateGeocodeInput(conn)
		addGeocodeData(conn)
		# Single-argument print(...) produces the same output on Python 2 and 3.
		print(fieldnames)
		print(priceRowCount)
		print(len(providerIds))

def createTables(conn):
	"""Drop and re-create every table of the hospital database on *conn*.

	Existing hospitals, discharges, prices, averagePrices and
	hospitalPriceIndices tables are discarded, so any data in them is lost.
	"""
	dropStatements = (
		'''DROP TABLE IF EXISTS hospitals''',
		'''DROP TABLE IF EXISTS discharges''',
		'''DROP TABLE IF EXISTS prices''',
		'''DROP TABLE IF EXISTS averagePrices''',
		'''DROP TABLE IF EXISTS hospitalPriceIndices''',
	)
	createStatements = (
		'''CREATE TABLE hospitals (Id INTEGER PRIMARY KEY, Name TEXT, Address TEXT, City TEXT, State TEXT, ZipCode TEXT, ReferralRegion TEXT, Latitude TEXT, Longitude TEXT)''',
		'''CREATE TABLE discharges (Id INTEGER PRIMARY KEY, DischargeId INTEGER, Description TEXT, IsInpatient BOOLEAN)''',
		'''CREATE TABLE prices (HospitalId INTEGER, DischargeId INTEGER, Count INTEGER, CoveredCharges REAL, TotalPayments REAL, PRIMARY KEY (HospitalId, DischargeId), FOREIGN KEY(HospitalId) REFERENCES hospitals(Id), FOREIGN KEY(DischargeId) REFERENCES discharges(Id))''',
		'''CREATE TABLE averagePrices (DischargeId INTEGER PRIMARY KEY, CoveredCharges REAL, TotalPayments REAL, NumberOfDischarges INTEGER, FOREIGN KEY(DischargeId) REFERENCES discharges(Id))''',
		'''CREATE TABLE hospitalPriceIndices (HospitalId INTEGER, CoveredCharges REAL, TotalPayments REAL, NumberOfDischarges INTEGER, IsInpatient BOOLEAN, FOREIGN KEY(HospitalId) REFERENCES hospitals(Id))''',
	)

	for statement in dropStatements:
		conn.execute(statement)
	# Runs between the drops and the creates, as in the original sequence.
	conn.execute('''VACUUM''')
	for statement in createStatements:
		conn.execute(statement)

# Words that stay fully upper-case regardless of provider.
_KEEP_UPPERCASE = (
	'UAB', 'UAMS', 'UCLA', 'UCSF', 'USC', 'UNM', 'UH', 'UT', 'UW', 'USD',
	'UMDNJ', 'LSU', 'UVA', 'UHHS', 'USMD', 'AHMC', 'ACMH', 'UPMC', 'ETMC',
	'TOPS', 'NW', 'NH', 'NY', 'PA', 'RI', 'LAC', 'CA', 'KS', 'WI', 'GA',
	'TN', 'IC', 'PHS', 'FW', 'LP', 'JFK', 'LLC', 'LLP', 'NEA', 'RMC', 'OSS',
	'CGH', 'RHC', 'FHN', 'ARH', 'OCH', 'SSM', 'UMC', 'EMH', 'TRMC', 'VHS',
	'JPS', 'LDS', 'CJW',
)

# Upper-case word -> replacement that applies to every provider.
_FIXED_CASE = {
	'ST': 'St.',
	'JR': 'Jr.',
	'FT': 'Ft.',
	'MT': 'Mt.',
	'AND': 'and',
	'OF': 'of',
	'FOR': 'for',
	'AN': 'an',
	'ON': 'on',
	'AT': 'at',
	'UMASS': 'UMass',
	'O\'CONNOR': 'O\'Connor',
	'O\'BLENESS': 'O\'Bleness',
	'MED': 'Medical',
	'MEM': 'Memorial',
	'CTR': 'Center',
	'CEN': 'Center',
	'MEDCTR': 'Medical Center',
	'SO': 'South',
	'HLTH': 'Health',
	'HLTHCR': 'Healthcare',
	'HLTHCARE': 'Healthcare',
	'LUTH': 'Lutheran',
	'LUKES': 'Luke\'s',
	'ANTHONYS': 'Anthony\'s',
	'MARYS': 'Mary\'s',
	'JOSEPHS': 'Joseph\'s',
	'VINCENTS': 'Vincent\'s',
	'MARGARETS': 'Margaret\'s',
	'DAVIDS': 'David\'s',
	'MARKS': 'Mark\'s',
	'UNIV': 'University',
	'REG': 'Regional',
	'HOSP': 'Hospital',
	'HSPTL': 'Hospital',
	'SURGCL': 'Surgical',
	'MCMILLAN': 'McMillan',
	'MACNEAL': 'MacNeal',
	'MCDOWELL': 'McDowell',
	'MCCREADY': 'McCready',
	'MCKINLEY': 'McKinley',
	'MCCULLOUGH': 'McCullough',
	'MCALESTER': 'McAlester',
	'MCCURTAIN': 'McCurtain',
	'MCLAREN': 'McLaren',
	'MCBRIDE': 'McBride',
	'MCLEOD': 'McLeod',
	'MCKENNAN': 'McKennan',
	'MCNAIRY': 'McNairy',
	'MCKENZIE': 'McKenzie',
	'MCKINNEY': 'McKinney',
	'MCKAY': 'McKay',
	'DESOTO': 'DeSoto',
	'DEKALB': 'DeKalb',
	'HEALTHALLIANCE': 'HealthAlliance',
	'FIRSTHEALTH': 'FirstHealth',
	'ANMED': 'AnMed',
	'ADCARE': 'AdCare',
	'MEDWEST': 'MedWest',
	'MEDCENTRAL': 'MedCentral',
	'MIDMICHIGAN': 'MidMichigan',
	'LIBERTYHEALTH': 'LibertyHealth',
	'CAREPLEX': 'CarePlex',
	'PEACHEALTH': 'PeaceHealth',
	'PEACEHEALTH': 'PeaceHealth',
	'DEPAUL': 'DePaul',
}
_FIXED_CASE.update((upper, upper) for upper in _KEEP_UPPERCASE)

# Upper-case word -> (replacement, provider ids for which the replacement is
# NOT used and the generic rules apply instead).
_CASE_UNLESS_PROVIDER = {
	'THE': ('the', (450862,)),
	'LA': ('LA', (140065, 150006)),
	'JOHNS': ('John\'s', (210009, 210029)),
}

# Provider ids whose single-letter words do not get a trailing period.
_NO_INITIAL_PERIOD_PROVIDERS = (190246, 270074, 320013, 350063, 370093, 370222, 370228)


def correctCase(word, providerId):
	"""Return *word* (one upper-case token of a hospital name) in display case.

	Rules, applied in order:
	  1. Exact-match replacements (abbreviations, acronyms, possessives,
	     mixed-case names), some of which depend on *providerId*.
	  2. Words containing '-', '/', '+' or '&' are split on the first of
	     those separators that occurs (checked in that order) and each piece
	     is converted recursively; '&' is re-joined as ' & '.
	  3. A single letter A-Z gets a trailing period, except for a fixed list
	     of providers.
	  4. Anything else keeps its first character (first two when the word
	     starts with '(') and has the rest lower-cased.
	"""
	# NOTE(review): the original carried the marker "done with 3200" here;
	# its meaning is not evident from the code.
	if word in _FIXED_CASE:
		return _FIXED_CASE[word]

	if word in _CASE_UNLESS_PROVIDER:
		replacement, excludedProviders = _CASE_UNLESS_PROVIDER[word]
		if providerId not in excludedProviders:
			return replacement
	# 'MS' is the only word replaced for one specific provider only.
	if word == 'MS' and providerId == 250097:
		return 'MS'

	for separator in ('-', '/', '+'):
		if separator in word:
			pieces = [correctCase(piece, providerId) for piece in word.split(separator)]
			return separator.join(pieces)
	if '&' in word:
		if len(word) == 1:
			return '&'
		pieces = [correctCase(piece, providerId) for piece in word.split('&')]
		return ' & '.join(pieces)

	if len(word) <= 1:
		# Empty strings and non-letters are returned untouched.
		if word and 'A' <= word <= 'Z' and providerId not in _NO_INITIAL_PERIOD_PROVIDERS:
			return word + '.'
		return word

	# Keep the opening parenthesis together with the capital that follows it.
	keep = 2 if word[0] == '(' else 1
	return word[:keep] + word[keep:].lower()

def fixUpName(name, providerId):
	"""Return the display form of an upper-case hospital *name*.

	A trailing ", the" / ",the" / " the" or a leading "the " (any case) is
	removed and, if any was found, a single 'The' is put in front of the
	result. The remaining space-separated words are each converted with
	correctCase(); empty and whitespace-only pieces are dropped.
	"""
	name = name.strip()
	articleFound = False

	# Tested in this order, each against the name as left by the previous
	# trim, so more than one suffix can be removed from the same name.
	for suffix in (', the', ',the', ' the'):
		if name.lower().endswith(suffix):
			name = name[:-len(suffix)]
			articleFound = True
	if name.lower().startswith('the '):
		name = name[4:]
		articleFound = True

	words = ['The'] if articleFound else []
	for token in name.split(' '):
		if token.strip() != '':
			words.append(correctCase(token, providerId))
	return ' '.join(words)

def addProvider(conn, row):
	"""Insert the hospital described by one provider-CSV *row* into hospitals.

	Columns used: row[1] provider id, row[2] name, row[3] address, row[4]
	city, row[5] state, row[6] ZIP code, row[7] referral region.
	"""
	providerId = int(row[1])
	displayName = fixUpName(row[2], providerId)
	# Left-pad to five characters with zeros; longer values are unchanged.
	zipCode = row[6].rjust(5, '0')
	values = (providerId, displayName, row[3], row[4], row[5], zipCode, row[7])
	conn.execute('''INSERT INTO hospitals (Id, Name, Address, City, State, ZipCode, ReferralRegion) VALUES (?, ?, ?, ?, ?, ?, ?)''', values)

def addDischarge(conn, id, description):
	"""Insert an inpatient discharge row, commit, and return its row id.

	conn        -- open sqlite3 connection holding the discharges table
	id          -- numeric discharge (DRG) code stored in DischargeId
	description -- text stored in Description

	Returns the Id (INTEGER PRIMARY KEY) of the row just inserted.
	"""
	cursor = conn.execute('''INSERT INTO discharges (DischargeId, Description, IsInpatient) VALUES (?, ?, ?)''', (id, description, True))
	conn.commit()
	# Take the id from the INSERT itself instead of querying it back by
	# (DischargeId, IsInpatient): those columns are not unique, so the lookup
	# could return a different, earlier row, and it scanned the whole table.
	return cursor.lastrowid

def addPrice(conn, hospitalId, dischargeId, count, coveredCharges, totalPayments):
	"""Insert one row into prices for the given hospital and discharge.

	The arguments are stored, in order, in the HospitalId, DischargeId,
	Count, CoveredCharges and TotalPayments columns.
	"""
	values = (hospitalId, dischargeId, count, coveredCharges, totalPayments)
	conn.execute('''INSERT INTO prices (HospitalId, DischargeId, Count, CoveredCharges, TotalPayments) VALUES (?, ?, ?, ?, ?)''', values)

def addAverage(conn, dischargeId, coveredCharges, totalPayments, numberOfDischarges):
	"""Insert one row into averagePrices for the given discharge.

	The arguments are stored, in order, in the DischargeId, CoveredCharges,
	TotalPayments and NumberOfDischarges columns.
	"""
	values = (dischargeId, coveredCharges, totalPayments, numberOfDischarges)
	conn.execute('''INSERT INTO averagePrices (DischargeId, CoveredCharges, TotalPayments, NumberOfDischarges) VALUES (?, ?, ?, ?)''', values)

def calculateAverages(conn, dischargeIds):
	"""Fill averagePrices with the unweighted mean price of each discharge.

	conn         -- open sqlite3 connection holding prices and averagePrices
	dischargeIds -- iterable of values matched against prices.DischargeId

	For every id, the CoveredCharges and TotalPayments of all matching price
	rows are averaged (each row counts once) and inserted into averagePrices.
	Ids without any price row are skipped; previously they raised
	ZeroDivisionError.
	"""
	for dischargeId in dischargeIds:
		rows = conn.execute('''SELECT CoveredCharges, TotalPayments FROM prices WHERE DischargeId=?''', (dischargeId,)).fetchall()
		if not rows:
			# Nothing to average for this discharge.
			continue
		count = len(rows)
		coveredAverage = sum(float(row[0]) for row in rows) / count
		totalAverage = sum(float(row[1]) for row in rows) / count
		conn.execute('''INSERT INTO averagePrices (DischargeId, CoveredCharges, TotalPayments) VALUES (?, ?, ?)''', (dischargeId, coveredAverage, totalAverage))

def calculatePriceIndices(conn, providerIds, dischargeIds):
	"""Fill hospitalPriceIndices with one price index row per provider.

	conn         -- open sqlite3 connection holding prices, averagePrices and
	                hospitalPriceIndices
	providerIds  -- iterable of values matched against prices.HospitalId
	dischargeIds -- iterable of discharge row ids to take into account

	For every price row of a provider whose discharge is in *dischargeIds*,
	the ratio price / average price is formed. The index is 100 times the
	geometric mean of those ratios, computed separately for covered charges
	and total payments. Providers without any such price row get no index
	row; previously they raised ZeroDivisionError.
	"""
	# reduce is a builtin on Python 2 only; this import works on 2.6+ and 3.
	from functools import reduce

	# Set membership instead of scanning a list for every price row.
	wantedDischarges = set(dischargeIds)
	# Averages are not modified here, so each one is read at most once.
	averages = {}

	# We're choosing to weight each procedure equally, instead of weighting
	# by price.
	for providerId in providerIds:
		coveredChargesIndices = []
		totalPaymentsIndices = []
		for row in conn.execute('''SELECT DischargeId, CoveredCharges, TotalPayments FROM prices WHERE HospitalId=?''', (providerId,)):
			dischargeId = int(row[0])
			if dischargeId not in wantedDischarges:
				continue
			if dischargeId not in averages:
				averages[dischargeId] = [float(x) for x in conn.execute('''SELECT CoveredCharges, TotalPayments FROM averagePrices WHERE DischargeId=?''', (dischargeId,)).fetchone()]
			(coveredChargesAverage, totalPaymentsAverage) = averages[dischargeId]
			coveredChargesIndices.append(float(row[1])/coveredChargesAverage)
			totalPaymentsIndices.append(float(row[2])/totalPaymentsAverage)
		if not coveredChargesIndices:
			# No comparable procedures: a geometric mean over zero ratios is
			# undefined, so no index is recorded for this provider.
			continue
		ratioCount = len(coveredChargesIndices)
		coveredChargeIndex = 100 * (math.pow(reduce(operator.mul, coveredChargesIndices, 1), 1.0/ratioCount))
		totalPaymentsIndex = 100 * (math.pow(reduce(operator.mul, totalPaymentsIndices, 1), 1.0/ratioCount))
		conn.execute('''INSERT INTO hospitalPriceIndices (HospitalId, CoveredCharges, TotalPayments, NumberOfDischarges, IsInpatient) VALUES (?, ?, ?, ?, ?)''', (providerId, coveredChargeIndex, totalPaymentsIndex, ratioCount, True))
 

def generateGeocodeInput(conn):
	"""Write geocode_input.txt: two header lines, then one line per hospital.

	Each hospital line is pipe-separated and carries the Id, Address, City,
	State and ZipCode columns of the hospitals table, with the culture fixed
	to en-US, the country fixed to US and the four trailing fields left
	empty. The file is created in the current directory and overwritten if
	it exists.
	"""
	headerLines = (
		'Bing Spatial Data Services, 2.0\n',
		'Id| GeocodeRequest/Culture| GeocodeRequest/Address/AddressLine| GeocodeRequest/Address/District| GeocodeRequest/Address/AdminDistrict| GeocodeRequest/Address/PostalCode| GeocodeRequest/Address/CountryRegion| GeocodeResponse/Point/Latitude| GeocodeResponse/Point/Longitude| StatusCode| FaultReason\n',
	)
	rowFormat = '%s|en-US|%s|%s|%s|%s|US||||\n'
	with open('geocode_input.txt', 'w') as out:
		out.writelines(headerLines)
		hospitals = conn.execute('''SELECT Id, Address, City, State, ZipCode FROM hospitals''')
		out.writelines(rowFormat % tuple(hospital) for hospital in hospitals)

def addGeocodeData(conn):
	"""Copy latitude/longitude from geocode_output.txt into the hospitals table.

	Every pipe-separated line whose first field parses as an integer provider
	id updates that hospital's Latitude and Longitude from fields 7 and 8.
	When the id appears in geocodeOverrides.idToGeocode, the first two items
	of that entry replace the values read from the file. Lines that cannot be
	parsed (such as the header lines) are skipped, and a missing or
	unreadable file leaves the table untouched. Prints the number of
	locations written and the number of overrides applied.
	"""
	import geocodeOverrides
	idMap = geocodeOverrides.idToGeocode
	overrides = 0
	geocodes = 0
	try:
		with open('geocode_output.txt', 'r') as geo:
			for row in geo:
				parts = row.split('|')
				try:
					providerId = int(parts[0])
					latitude = parts[7]
					longitude = parts[8]
				except (ValueError, IndexError):
					# Non-numeric id or too few fields: not a data line.
					continue
				if providerId in idMap:
					overrides += 1
					latitude = idMap[providerId][0]
					longitude = idMap[providerId][1]
				# Database errors are no longer swallowed; only unparsable
				# lines are skipped.
				conn.execute('''UPDATE hospitals SET Latitude=?, Longitude=? WHERE id=?''', (latitude, longitude, providerId))
				geocodes += 1
	except IOError:
		# Best effort: the geocoder output may not exist yet, in which case
		# the counts printed below stay at what was processed so far.
		pass
	# Single-argument print(...) produces the same output on Python 2 and 3.
	print("Wrote %d locations" % geocodes)
	print("Got %d overrides" % overrides)

# Run the import only when this file is executed as a script.
if __name__ == '__main__':
	main()
