import unittest
from unittest.mock import patch, MagicMock

from pyspark.sql import functions as f, DataFrame, SparkSession
from pyspark_test import assert_pyspark_df_equal
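# NOTE: assert_pyspark_df_equal comes from the third-party pyspark-test package.
# The @patch("class_to_test.SparkSession") target below assumes ClassToTest is
# defined in a module named class_to_test; class and test are shown in one
# snippet here only for readability, so adjust the target string if your module
# layout differs.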


class ClassToTest:
    def __init__(self) -> None:
        self.spark = SparkSession.builder.getOrCreate()

    def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
        # Read the second DataFrame from parquet and left-join it onto df1.
        df2 = self.spark.read.format('parquet').load(df2_path)
        return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')
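
# The test keeps Spark I/O out of the picture: instead of writing a real
# parquet file, it mirrors the exact chain do_stuff performs on the session
# (spark.read.format(...).load(...)) on a MagicMock and has it return an
# in-memory DataFrame.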


class TestClassToTest(unittest.TestCase):
    def setUp(self) -> None:
        self.spark = SparkSession.builder.getOrCreate()

    @patch("class_to_test.SparkSession")
    def test_do_stuff(self, mock_spark: MagicMock) -> None:
        # Stub the session so spark.read.format(...).load(...) returns an
        # in-memory DataFrame instead of reading from disk. Note that read is
        # an attribute, not a call, so no .return_value is needed after it.
        spark = MagicMock()
        spark.read.format.return_value.load.return_value = \
            self.spark.createDataFrame([(1, 2)], ["key2", "c2"])
        # ClassToTest obtains its session via SparkSession.builder.getOrCreate(),
        # so that is the chain that must return the stubbed session.
        mock_spark.builder.getOrCreate.return_value = spark

        input_df = self.spark.createDataFrame([(1, 1)], ["key1", "c1"])
        actual_df = ClassToTest().do_stuff(input_df, "df2", "key1", "key2")
        expected_df = self.spark.createDataFrame(
            [(1, 1, 1, 2)],
            ["key1", "c1", "key2", "c2"])
        assert_pyspark_df_equal(actual_df, expected_df)


if __name__ == "__main__":
    unittest.main()
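
# Manual smoke test against a real SparkSession (uncomment to run): it writes
# a small parquet dataset to a local "df2" path and joins against it.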
# ctt = ClassToTest()
# ctt.spark.createDataFrame([(1, 2)], ["key2", "c2"]).write.format("parquet").mode("overwrite").save("df2")
# df1 = ctt.spark.createDataFrame([(1, 1)], ["key1", "c1"])
# res = ctt.do_stuff(df1, "df2", "key1", "key2")
# res.show()