forked from LearningJournal/Spark-Programming-In-Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDataLoader.py
More file actions
61 lines (49 loc) · 2.07 KB
/
DataLoader.py
File metadata and controls
61 lines (49 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from lib import ConfigLoader
def get_account_schema():
    """Return the schema of the accounts data set as a DDL-style string.

    The result is a comma-separated list of ``column_name type`` pairs.
    ``read_accounts`` passes it to the reader's ``schema(...)`` call when
    loading the accounts CSV files.
    """
    # Adjacent literals are joined at compile time; the explicit newline
    # escapes keep the returned text identical to the original multi-line
    # literal.
    return (
        "load_date date,active_ind int,account_id string,\n"
        "source_sys string,account_start_date timestamp,\n"
        "legal_title_1 string,legal_title_2 string,\n"
        "tax_id_type string,tax_id string,branch_code string,country string"
    )
def get_party_schema():
    """Return the schema of the parties data set as a DDL-style string.

    The result is a comma-separated list of ``column_name type`` pairs.
    ``read_parties`` passes it to the reader's ``schema(...)`` call when
    loading the parties CSV files.
    """
    # The explicit newline escape keeps the returned text identical to the
    # original multi-line literal.
    return (
        "load_date date,account_id string,party_id string,\n"
        "relation_type string,relation_start_date timestamp"
    )
def get_address_schema():
    """Return the schema of the party-address data set as a DDL-style string.

    The result is a comma-separated list of ``column_name type`` pairs.
    ``read_address`` passes it to the reader's ``schema(...)`` call when
    loading the party-address CSV files.
    """
    # The explicit newline escapes keep the returned text identical to the
    # original multi-line literal.
    return (
        "load_date date,party_id string,address_line_1 string,\n"
        "address_line_2 string,city string,postal_code string,\n"
        "country_of_address string,address_start_date date"
    )
def read_accounts(spark, env, enable_hive, hive_db):
    """Load the accounts data set and apply the configured runtime filter.

    Args:
        spark: Session object used to run SQL or read files; this function
            uses its ``sql`` method and ``read`` attribute.
        env: Environment identifier forwarded to
            ``ConfigLoader.get_data_filter``.
        enable_hive: When truthy, read the Hive table ``<hive_db>.accounts``;
            otherwise read the CSV files under ``test_data/accounts/``.
        hive_db: Hive database name; only used when ``enable_hive`` is
            truthy.

    Returns:
        The loaded data after ``.where(runtime_filter)`` has been applied.
    """
    runtime_filter = ConfigLoader.get_data_filter(env, "account.filter")

    if enable_hive:
        return spark.sql(f"select * from {hive_db}.accounts").where(runtime_filter)

    # File-based fallback: CSV with a header row and an explicit schema.
    return (
        spark.read.format("csv")
        .option("header", "true")
        .schema(get_account_schema())
        .load("test_data/accounts/")
        .where(runtime_filter)
    )
def read_parties(spark, env, enable_hive, hive_db):
    """Load the parties data set and apply the configured runtime filter.

    Args:
        spark: Session object used to run SQL or read files; this function
            uses its ``sql`` method and ``read`` attribute.
        env: Environment identifier forwarded to
            ``ConfigLoader.get_data_filter``.
        enable_hive: When truthy, read the Hive table ``<hive_db>.parties``;
            otherwise read the CSV files under ``test_data/parties/``.
        hive_db: Hive database name; only used when ``enable_hive`` is
            truthy.

    Returns:
        The loaded data after ``.where(runtime_filter)`` has been applied.
    """
    runtime_filter = ConfigLoader.get_data_filter(env, "party.filter")

    if enable_hive:
        return spark.sql(f"select * from {hive_db}.parties").where(runtime_filter)

    # File-based fallback: CSV with a header row and an explicit schema.
    return (
        spark.read.format("csv")
        .option("header", "true")
        .schema(get_party_schema())
        .load("test_data/parties/")
        .where(runtime_filter)
    )
def read_address(spark, env, enable_hive, hive_db):
    """Load the party-address data set and apply the configured runtime filter.

    Args:
        spark: Session object used to run SQL or read files; this function
            uses its ``sql`` method and ``read`` attribute.
        env: Environment identifier forwarded to
            ``ConfigLoader.get_data_filter``.
        enable_hive: When truthy, read the Hive table
            ``<hive_db>.party_address``; otherwise read the CSV files under
            ``test_data/party_address/``.
        hive_db: Hive database name; only used when ``enable_hive`` is
            truthy.

    Returns:
        The loaded data after ``.where(runtime_filter)`` has been applied.
    """
    runtime_filter = ConfigLoader.get_data_filter(env, "address.filter")

    if enable_hive:
        return spark.sql(f"select * from {hive_db}.party_address").where(runtime_filter)

    # File-based fallback: CSV with a header row and an explicit schema.
    return (
        spark.read.format("csv")
        .option("header", "true")
        .schema(get_address_schema())
        .load("test_data/party_address/")
        .where(runtime_filter)
    )