610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
def build_rest_api_config(
    spark: SparkSession,
    source: Any,
    *,
    schema: Optional[StructType] = None,
    source_config: Optional[Mapping[str, Any]] = None,
    options: Optional[Mapping[str, Any]] = None,
    headers: Optional[Mapping[str, str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Build the options payload consumed by the REST data source.

    Configuration mappings are merged with later sources winning:
    ``source_config`` < ``options`` < ``**kwargs``.  The explicit
    ``headers`` argument is layered on top of any ``headers`` entry
    found inside the merged config.

    Args:
        spark: Active session; consulted only for
            ``sparkContext.defaultParallelism`` when the config does not
            set ``parallelism`` explicitly.
        source: A single URL string or a sequence of URL strings.
        schema: Optional explicit result schema; required when
            ``infer_schema`` is disabled.
        source_config: Lowest-precedence configuration mapping.
        options: Mid-precedence configuration mapping.
        headers: HTTP headers merged over config-supplied headers.
        **kwargs: Highest-precedence configuration overrides.

    Returns:
        A JSON-able dict of normalized options for the REST reader.

    Raises:
        TypeError: For malformed ``records_field``, ``pagination``,
            ``params``, or ``source`` values.
        ValueError: For invalid endpoints, unsupported request settings,
            or a missing schema when inference is disabled.
    """
    config: Dict[str, Any] = {}
    # Merge precedence: source_config < options < kwargs (later wins).
    for mapping in (source_config, options, kwargs):
        if mapping:
            config.update(mapping)

    # Normalize records_field into a list of path segments (or None).
    # bytes/bytearray are Sequences but would explode into per-byte
    # "segments", so they are rejected along with other bad types.
    records_field = config.get("records_field")
    if isinstance(records_field, str):
        records_path = records_field.split(".") if records_field else None
    elif records_field is None:
        records_path = None
    elif isinstance(records_field, Sequence) and not isinstance(
        records_field, (bytes, bytearray)
    ):
        records_path = [str(part) for part in records_field]
    else:
        raise TypeError("records_field must be a string or sequence")

    # Default to inferring the schema only when none was supplied.
    infer_schema = bool(config.get("infer_schema", schema is None))
    if not infer_schema and schema is None:
        raise ValueError("schema must be provided when infer_schema=False for REST API reads")

    request_timeout = float(config.get("request_timeout", 30.0))
    max_retries = int(config.get("max_retries", 3))
    backoff_factor = float(config.get("retry_backoff", 0.5))

    # Explicit `headers` argument overrides config-supplied headers
    # because it is applied second.
    base_headers: Dict[str, str] = {}
    for header_map in (config.get("headers"), headers):
        if isinstance(header_map, Mapping):
            base_headers.update({str(k): str(v) for k, v in header_map.items()})

    request_kwargs: Dict[str, Any] = {}
    if isinstance(config.get("request_kwargs"), Mapping):
        request_kwargs.update(config["request_kwargs"])

    request_type = str(config.get("request_type", "GET")).upper()
    if request_type not in {"GET", "POST"}:
        raise ValueError("request_type must be either 'GET' or 'POST'")

    request_body = config.get("request_body")
    if request_body is not None and request_type != "POST":
        raise ValueError("request_body is only supported when request_type='POST'")
    if request_body is not None:
        body_mode = config.get("request_body_type")
        if body_mode is None:
            # Mappings default to JSON encoding; anything else is sent raw.
            body_mode = "json" if isinstance(request_body, Mapping) else "data"
        body_mode = str(body_mode).lower()
        # setdefault preserves any value the caller already put in
        # request_kwargs for the same key.
        if body_mode == "json":
            request_kwargs.setdefault("json", request_body)
        elif body_mode in {"data", "form", "raw", "content"}:
            # All four aliases map to the requests-style `data` payload.
            request_kwargs.setdefault("data", request_body)
        else:
            raise ValueError(
                "request_body_type must be one of {'json', 'data', 'form', 'raw', 'content'}"
            )

    pagination = config.get("pagination")
    if pagination is not None and not isinstance(pagination, Mapping):
        raise TypeError("pagination configuration must be a mapping when provided")

    # Single lookup with explicit branches.  The previous
    # `if params and not isinstance(params, Mapping)` check let falsy
    # non-mappings ([], 0, "") slip through silently; now any non-None,
    # non-mapping value is rejected.
    raw_params = config.get("params")
    if raw_params is None:
        params: Dict[str, Any] = {}
    elif isinstance(raw_params, Mapping):
        params = dict(raw_params)
    else:
        raise TypeError("params configuration must be a mapping if provided")

    include_response_payload = bool(config.get("include_response_payload", False))
    response_payload_field: Optional[str] = None
    if include_response_payload:
        response_payload_field = str(config.get("response_payload_field", "response_payload"))
        if not response_payload_field:
            raise ValueError("response_payload_field must be a non-empty string when enabled")

    # Accept a single URL or any non-string, non-bytes sequence of URLs.
    work_source: List[str]
    if isinstance(source, str):
        work_source = [source]
    elif isinstance(source, Sequence) and not isinstance(source, (str, bytes)):
        work_source = [str(url) for url in source]
    else:
        raise TypeError("source must be a string URL or a sequence of URLs for REST reads")
    for url in work_source:
        if not _validate_http_url(url):
            raise ValueError(f"Invalid REST endpoint: {url}")

    spark_parallelism = config.get("parallelism")
    if spark_parallelism is None:
        # Fall back to the cluster default; guarantee at least one partition.
        spark_parallelism = spark.sparkContext.defaultParallelism or 1

    return {
        "sources": work_source,
        "params": params,
        "pagination": pagination,
        "records_field": records_path,
        "request_type": request_type,
        "request_kwargs": _normalize_jsonable(request_kwargs),
        "headers": base_headers,
        "timeout": request_timeout,
        "max_retries": max_retries,
        "backoff_factor": backoff_factor,
        "include_response_payload": include_response_payload,
        "response_payload_field": response_payload_field,
        "parallelism": int(spark_parallelism),
        "infer_schema": infer_schema,
    }